diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 69c099a262..4da070bdbf 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -63,7 +63,8 @@ pnpm analyze-component --review ### File Naming -- Test files: `ComponentName.spec.tsx` (same directory as component) +- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory +- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`. - Integration tests: `web/__tests__/` directory ## Test Structure Template diff --git a/.agents/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx index 6b7803bd4b..ff38f88d23 100644 --- a/.agents/skills/frontend-testing/assets/component-test.template.tsx +++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx @@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event' // Router (if component uses useRouter, usePathname, useSearchParams) // WHY: Isolates tests from Next.js routing, enables testing navigation behavior // const mockPush = vi.fn() -// vi.mock('next/navigation', () => ({ +// vi.mock('@/next/navigation', () => ({ // useRouter: () => ({ push: mockPush }), // usePathname: () => '/test-path', // })) diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 0000000000..15c697730a --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: true + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 54702c914a..24af948732 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,9 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: "./web/.nvmrc" + working-directory: web + node-version-file: .nvmrc cache: true - run-install: | - - cwd: ./web - args: ['--frozen-lockfile'] + run-install: true diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml index c0d1818691..b0f0a36bc9 100644 --- a/.github/workflows/anti-slop.yml +++ b/.github/workflows/anti-slop.yml @@ -12,7 +12,7 @@ jobs: anti-slop: runs-on: ubuntu-latest steps: - - uses: peakoss/anti-slop@v0 + - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 with: github-token: ${{ secrets.GITHUB_TOKEN }} close-pr: false diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 12d7ff33c7..6b87946221 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -2,6 +2,12 @@ name: Run Pytest on: workflow_call: + secrets: + CODECOV_TOKEN: + required: false + +permissions: + contents: read concurrency: group: api-tests-${{ github.head_ref || github.run_id }} @@ -11,6 +17,8 @@ jobs: test: name: API Tests runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: run: shell: bash @@ -24,10 +32,11 @@ jobs: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -79,21 +88,12 @@ jobs: api/tests/test_containers_integration_tests \ api/tests/unit_tests - - name: Coverage Summary - run: | - set -x - # Extract coverage percentage and create a summary - TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])') - - # Create a detailed coverage summary - echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY - echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - { - echo "" - echo "
File-level coverage (click to expand)" - echo "" - echo '```' - uv run --project api coverage report -m - echo '```' - echo "
" - } >> $GITHUB_STEP_SUMMARY + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }} + uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 + with: + files: ./coverage.xml + disable_search: true + flags: api + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 73ca94f98f..be6186980e 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -39,7 +39,7 @@ jobs: with: python-version: "3.11" - - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Generate Docker Compose if: steps.docker-compose-changes.outputs.any_changed == 'true' @@ -94,11 +94,6 @@ jobs: find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete - # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - - name: mdformat - run: | - uvx --python 3.13 mdformat . --exclude ".agents/skills/**" - - name: Setup web environment if: steps.web-changes.outputs.any_changed == 'true' uses: ./.github/actions/setup-web diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index c567a4bfe0..ffb9734e48 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: "3.12" @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index a19cb50abc..69023c24cc 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -56,16 +56,14 @@ jobs: needs: check-changes if: needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml + secrets: inherit web-tests: name: Web Tests needs: check-changes if: needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml - with: - base_sha: ${{ github.event.before || github.event.pull_request.base.sha }} - diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }} - head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }} + secrets: inherit style-check: name: Style Check diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index f50df229d5..a00f469bbe 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5e037d2541..23ae36f7b1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: false python-version: "3.12" @@ -84,20 +84,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +114,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 9af6649328..1869254295 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -120,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72 + uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 0b771c1af7..f45f2137d6 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -31,7 +31,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index be2595a599..d40cd4bfeb 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -2,16 +2,9 @@ name: Web Tests on: workflow_call: - inputs: - base_sha: + secrets: + CODECOV_TOKEN: required: false - type: string - diff_range_mode: - required: false - type: string - head_sha: - required: false - type: string permissions: contents: read @@ -63,7 +56,7 @@ jobs: needs: [test] runs-on: ubuntu-latest env: - VITEST_COVERAGE_SCOPE: app-components + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: run: shell: bash @@ -87,52 +80,16 @@ jobs: merge-multiple: true - name: Merge reports - run: vp test --merge-reports --reporter=json --reporter=agent --coverage + run: vp test --merge-reports --coverage --silent=passed-only - - name: Report app/components baseline coverage - run: node ./scripts/report-components-coverage-baseline.mjs - - - name: Report app/components test touch - env: - BASE_SHA: ${{ inputs.base_sha }} - DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }} - HEAD_SHA: ${{ inputs.head_sha }} - run: node ./scripts/report-components-test-touch.mjs - - - name: Check app/components pure diff coverage - env: - BASE_SHA: ${{ inputs.base_sha }} - DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }} - HEAD_SHA: ${{ inputs.head_sha }} - run: node ./scripts/check-components-diff-coverage.mjs - - - name: Check Coverage Summary - if: always() - id: coverage-summary - run: | - set -eo pipefail - - COVERAGE_FILE="coverage/coverage-final.json" - COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" - - if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then - echo "has_coverage=true" >> "$GITHUB_OUTPUT" - exit 0 - fi - - echo "has_coverage=false" >> "$GITHUB_OUTPUT" - echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY" - echo "" >> "$GITHUB_STEP_SUMMARY" - echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY" - - - name: Upload Coverage Artifact - if: steps.coverage-summary.outputs.has_coverage == 'true' - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 with: - name: web-coverage-report - path: web/coverage - retention-days: 30 - if-no-files-found: error + directory: web/coverage + flags: web + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} web-build: name: Web Build diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7f007af67..775401bfa5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. + +## Automated Agent Contributions + +> [!NOTE] +> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/api/.env.example b/api/.env.example index 40e1c2dfdf..9672a99d55 100644 --- a/api/.env.example +++ b/api/.env.example @@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/commands/plugin.py b/api/commands/plugin.py index 2dfbd73b3a..c34391025a 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -1,9 +1,11 @@ import json import logging -from typing import Any +from typing import Any, cast import click from pydantic import TypeAdapter +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult from configs import dify_config from core.helper import encrypter @@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) return - deleted_count = ( - db.session.query(ToolOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) + deleted_count = cast( + CursorResult, + db.session.execute( + delete(ToolOAuthSystemClient).where( + ToolOAuthSystemClient.provider == provider_name, + ToolOAuthSystemClient.plugin_id == plugin_id, + ) + ), + ).rowcount if deleted_count > 0: click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) @@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params): click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) return - deleted_count = ( - db.session.query(TriggerOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) + deleted_count = cast( + CursorResult, + db.session.execute( + delete(TriggerOAuthSystemClient).where( + TriggerOAuthSystemClient.provider == provider_name, + TriggerOAuthSystemClient.plugin_id == plugin_id, + ) + ), + ).rowcount if deleted_count > 0: click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) @@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params): return click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) - deleted_count = ( - db.session.query(DatasourceOauthParamConfig) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) + deleted_count = cast( + CursorResult, + db.session.execute( + delete(DatasourceOauthParamConfig).where( + DatasourceOauthParamConfig.provider == provider_name, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + ), + ).rowcount if deleted_count > 0: click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) @@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str): # deal notion credentials deal_notion_count = 0 - notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() + notion_credentials = db.session.scalars( + select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion") + ).all() if notion_credentials: notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} for notion_credential in notion_credentials: @@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str): notion_credentials_tenant_mapping[tenant_id] = [] notion_credentials_tenant_mapping[tenant_id].append(notion_credential) for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) if not tenant: continue try: @@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str): db.session.commit() # deal firecrawl credentials deal_firecrawl_count = 0 - firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() + firecrawl_credentials = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl") + ).all() if firecrawl_credentials: firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} for firecrawl_credential in firecrawl_credentials: @@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str): firecrawl_credentials_tenant_mapping[tenant_id] = [] firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) if not tenant: continue try: @@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str): db.session.commit() # deal jina credentials deal_jina_count = 0 - jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all() + jina_credentials = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader") + ).all() if jina_credentials: jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} for jina_credential in jina_credentials: @@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str): jina_credentials_tenant_mapping[tenant_id] = [] jina_credentials_tenant_mapping[tenant_id].append(jina_credential) for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) if not tenant: continue try: diff --git a/api/commands/storage.py b/api/commands/storage.py index fa890a855a..f23b17680a 100644 --- a/api/commands/storage.py +++ b/api/commands/storage.py @@ -1,7 +1,10 @@ import json +from typing import cast import click import sqlalchemy as sa +from sqlalchemy import update +from sqlalchemy.engine import CursorResult from configs import dify_config from extensions.ext_database import db @@ -740,14 +743,17 @@ def migrate_oss( else: try: source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL - updated = ( - db.session.query(UploadFile) - .where( - UploadFile.storage_type == source_storage_type, - UploadFile.key.in_(copied_upload_file_keys), - ) - .update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False) - ) + updated = cast( + CursorResult, + db.session.execute( + update(UploadFile) + .where( + UploadFile.storage_type == source_storage_type, + UploadFile.key.in_(copied_upload_file_keys), + ) + .values(storage_type=dify_config.STORAGE_TYPE) + ), + ).rowcount db.session.commit() click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green")) except Exception as e: diff --git a/api/commands/system.py b/api/commands/system.py index 604f0e34d0..39b2e991ed 100644 --- a/api/commands/system.py +++ b/api/commands/system.py @@ -2,6 +2,7 @@ import logging import click import sqlalchemy as sa +from sqlalchemy import delete, select, update from sqlalchemy.orm import sessionmaker from configs import dify_config @@ -41,7 +42,7 @@ def reset_encrypt_key_pair(): click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) return with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - tenants = session.query(Tenant).all() + tenants = session.scalars(select(Tenant)).all() for tenant in tenants: if not tenant: click.echo(click.style("No workspaces found. Run /install first.", fg="red")) @@ -49,8 +50,8 @@ def reset_encrypt_key_pair(): tenant.encrypt_public_key = generate_key_pair(tenant.id) - session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() + session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id)) + session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id)) click.echo( click.style( @@ -93,7 +94,7 @@ def convert_to_agent_apps(): app_id = str(i.id) if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.scalar(select(App).where(App.id == app_id)) if app is not None: apps.append(app) @@ -108,8 +109,8 @@ def convert_to_agent_apps(): db.session.commit() # update conversation mode to agent - db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT} + db.session.execute( + update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT) ) db.session.commit() @@ -177,7 +178,7 @@ where sites.id is null limit 1000""" continue try: - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.scalar(select(App).where(App.id == app_id)) if not app: logger.info("App %s not found", app_id) continue diff --git a/api/commands/vector.py b/api/commands/vector.py index 52ce26c26d..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -41,14 +42,13 @@ def migrate_annotation_vector_database(): # get apps info per_page = 50 with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - apps = ( - session.query(App) + apps = session.scalars( + select(App) .where(App.status == "normal") .order_by(App.created_at.desc()) .limit(per_page) .offset((page - 1) * per_page) - .all() - ) + ).all() if not apps: break except SQLAlchemyError: @@ -63,8 +63,8 @@ def migrate_annotation_vector_database(): try: click.echo(f"Creating app annotation index: {app.id}") with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1) ) if not app_annotation_setting: @@ -72,10 +72,10 @@ def migrate_annotation_vector_database(): click.echo(f"App annotation setting disabled: {app.id}") continue # get dataset_collection_binding info - dataset_collection_binding = ( - session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) - .first() + dataset_collection_binding = session.scalar( + select(DatasetCollectionBinding).where( + DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id + ) ) if not dataset_collection_binding: click.echo(f"App annotation collection binding not found: {app.id}") @@ -86,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -178,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -205,11 +207,11 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == dataset.collection_binding_id) - .one_or_none() - ) + dataset_collection_binding = db.session.execute( + select(DatasetCollectionBinding).where( + DatasetCollectionBinding.id == dataset.collection_binding_id + ) + ).scalar_one_or_none() if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: @@ -270,7 +272,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -334,7 +336,7 @@ def add_qdrant_index(field: str): create_count = 0 try: - bindings = db.session.query(DatasetCollectionBinding).all() + bindings = db.session.scalars(select(DatasetCollectionBinding)).all() if not bindings: click.echo(click.style("No dataset collection bindings found.", fg="red")) return @@ -421,10 +423,10 @@ def old_metadata_migration(): if field.value == key: break else: - dataset_metadata = ( - db.session.query(DatasetMetadata) + dataset_metadata = db.session.scalar( + select(DatasetMetadata) .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) - .first() + .limit(1) ) if not dataset_metadata: dataset_metadata = DatasetMetadata( @@ -436,7 +438,7 @@ def old_metadata_migration(): ) db.session.add(dataset_metadata) db.session.flush() - dataset_metadata_binding = DatasetMetadataBinding( + dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding( tenant_id=document.tenant_id, dataset_id=document.dataset_id, metadata_id=dataset_metadata.id, @@ -445,14 +447,14 @@ def old_metadata_migration(): ) db.session.add(dataset_metadata_binding) else: - dataset_metadata_binding = ( - db.session.query(DatasetMetadataBinding) # type: ignore + dataset_metadata_binding = db.session.scalar( + select(DatasetMetadataBinding) .where( DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id, ) - .first() + .limit(1) ) if not dataset_metadata_binding: dataset_metadata_binding = DatasetMetadataBinding( diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 8f956745b1..c8e4f7309f 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings): description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", default="COARSE_MODE", ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field( + description="Auto build row count increment threshold (default is 500)", + default=500, + ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field( + description="Auto build row count increment ratio threshold (default is 0.05)", + default=0.05, + ) + + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field( + description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)", + default=300, + ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b6d1df319e..783cb5c444 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,7 +1,7 @@ import flask_restx from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -33,16 +34,10 @@ api_key_list_model = console_ns.model( def _get_resource(resource_id, tenant_id, resource_model): - if resource_model == App: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() - else: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() + with Session(db.engine) as session: + resource = session.execute( + select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) + ).scalar_one_or_none() if resource is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") @@ -53,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -80,10 +75,13 @@ class BaseApiKeyListResource(Resource): resource_id = str(resource_id) _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .count() + current_key_count: int = ( + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -94,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -107,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -119,14 +118,14 @@ class BaseApiKeyResource(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -137,7 +136,7 @@ class BaseApiKeyResource(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() return {"result": "success"}, 204 @@ -162,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -178,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -202,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -218,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..7e41260eeb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5eb61493c3..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -5,7 +5,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -376,8 +376,12 @@ class CompletionConversationApi(Resource): # FIXME, the type ignore in this file if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) elif args.annotation_status == "not_annotated": query = ( @@ -454,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -511,8 +513,12 @@ class ChatConversationApi(Resource): match args.annotation_status: case "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) case "not_annotated": query = ( @@ -587,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..442d0d2324 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 2025048e09..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,18 +99,18 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() description = payload.description - if description is None: - pass - elif not description: + if description is None or not description: server.description = app_model.description or "" else: server.description = description + server.name = app_model.name + server.parameters = json.dumps(payload.parameters, ensure_ascii=False) if payload.status: try: @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 3beea2a385..736e7dbe17 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -4,7 +4,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -30,6 +30,7 @@ from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -243,27 +244,25 @@ class ChatMessageListApi(Resource): def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) - .first() + .limit(1) ) if not conversation: raise NotFound("Conversation Not Exists.") if args.first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args.first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1) ) if not first_message: raise NotFound("First message not found") - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -271,16 +270,14 @@ class ChatMessageListApi(Resource): ) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() # Initialize has_more based on whether we have a full page if len(history_messages) == args.limit: @@ -325,7 +322,9 @@ class MessageFeedbackApi(Resource): message_id = str(args.message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -335,7 +334,7 @@ class MessageFeedbackApi(Resource): if not args.rating and feedback: db.session.delete(feedback) elif args.rating and feedback: - feedback.rating = args.rating + feedback.rating = FeedbackRating(args.rating) feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") @@ -347,9 +346,9 @@ class MessageFeedbackApi(Resource): app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=rating_value, + rating=FeedbackRating(rating_value), content=args.content, - from_source="admin", + from_source=FeedbackFromSource.ADMIN, from_account_id=current_user.id, ) db.session.add(feedback) @@ -374,7 +373,9 @@ class MessageAnnotationCountApi(Resource): @login_required @account_initialization_required def get(self, app_model): - count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() + count = db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) + ) return {"count": count} @@ -478,7 +479,9 @@ class MessageApi(Resource): def get(self, app_model, message_id: str): message_id = str(message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..e9bd30ba7e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 837245ecb1..d59aa44718 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns @@ -46,13 +46,14 @@ from models import App from models.model import AppMode from models.workflow import Workflow from services.app_generate_service import AppGenerateService -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -284,7 +285,9 @@ class DraftWorkflowApi(Resource): workflow_service = WorkflowService() try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + args.get("environment_variables") or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -994,6 +997,43 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows//restore") +class DraftWorkflowRestoreApi(Resource): + @console_ns.doc("restore_workflow_to_draft") + @console_ns.doc(description="Restore a published workflow version into the draft workflow") + @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"}) + @console_ns.response(200, "Workflow restored successfully") + @console_ns.response(400, "Source workflow must be published") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App, workflow_id: str): + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + + try: + workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app_model, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index c2a95ddad2..9e7faa09c5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1ed931b0d7..844f3c91ff 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( @@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: @@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() # Create workspace if needed if ( diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 112e152432..5c7011fd22 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,9 +1,10 @@ import logging +import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -112,6 +113,9 @@ class OAuthCallback(Resource): error_text = e.response.text logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 + except ValueError as e: + logger.warning("OAuth error with %s", provider, exc_info=True) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}") if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService.get_invitation_by_token(token=invite_token) @@ -176,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cd..27c772fbe0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -3,7 +3,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -54,7 +55,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): @@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -777,7 +781,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -826,7 +833,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") @@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 0c441553be..897724182f 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -10,7 +10,7 @@ import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field -from sqlalchemy import asc, desc, select +from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -27,6 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db @@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() ) if dataset_process_rule: mode = dataset_process_rule.mode @@ -298,6 +298,7 @@ class DatasetDocumentListApi(Resource): if sort == "hit_count": sub_query = ( sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .where(DocumentSegment.dataset_id == str(dataset_id)) .group_by(DocumentSegment.document_id) .subquery() ) @@ -329,21 +330,23 @@ class DatasetDocumentListApi(Resource): if fetch: for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) document.completed_segments = completed_segments document.total_segments = total_segments @@ -447,7 +450,7 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: @@ -461,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -520,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource): if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -585,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) + file_detail = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) if file_detail is None: @@ -671,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -722,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) - .count() + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) + ) + or 0 ) # Create a dictionary with document attributes and additional fields @@ -1257,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - log = ( - db.session.query(DocumentPipelineExecutionLog) - .filter_by(document_id=document_id) + log = db.session.scalar( + select(DocumentPipelineExecutionLog) + .where(DocumentPipelineExecutionLog.document_id == document_id) .order_by(DocumentPipelineExecutionLog.created_at.desc()) - .first() + .limit(1) ) if not log: return { @@ -1327,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b712..7333fcaa07 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" from services.summary_index_service import SummaryIndexService - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore # Query summary for this segment (only enabled summaries) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None @@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource): # Add summary to each segment segments_with_summary = [] for segment in segments.items: - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore segment_dict["summary"] = summaries.get(segment.id) segments_with_summary.append(segment_dict) @@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 99ff49d79d..cd568cf835 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.hit_testing_service import HitTestingService logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ logger = logging.getLogger(__name__) class HitTestingPayload(BaseModel): query: str = Field(max_length=250) - retrieval_model: dict[str, Any] | None = None + retrieval_model: RetrievalModel | None = None external_retrieval_model: dict[str, Any] | None = None attachment_ids: list[str] | None = None diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 6e0cd31b8d..4f31093cfe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource): type = request.args.get("type", default="built-in", type=str) rag_pipeline_service = RagPipelineService() pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) + if pipeline_template is None: + return {"error": "Pipeline template not found from upstream service."}, 404 return pipeline_template, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 51cdcc0c7a..3912cc73ca 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -6,7 +6,7 @@ from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.common.schema import register_schema_models @@ -16,7 +16,11 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) -from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow import ( + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE, + workflow_model, + workflow_pagination_model, +) from controllers.console.app.workflow_run import ( workflow_run_detail_model, workflow_run_node_execution_list_model, @@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline from models.model import EndUser -from services.errors.app import WorkflowHashNotEqualError +from models.workflow import Workflow +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource): abort(415) payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + rag_pipeline_service = RagPipelineService() try: - environment_variables_list = payload.environment_variables or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + payload.environment_variables or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=payload.graph, @@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows//restore") +class RagPipelineDraftWorkflowRestoreApi(Resource): + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, workflow_id: str): + current_user, _ = current_account_with_tenant() + rag_pipeline_service = RagPipelineService() + + try: + workflow = rag_pipeline_service.restore_published_workflow_to_draft( + pipeline=pipeline, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + # Use a stable, predefined message to keep the 400 response consistent + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): @setup_required diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 3ef1341abc..d533e6c5b1 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar +from sqlalchemy import select + from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]): del kwargs["pipeline_id"] - pipeline = ( - db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) - .first() + pipeline = db.session.scalar( + select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1) ) if not pipeline: diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py index da306fbc9d..757061d8dd 100644 --- a/api/controllers/console/explore/banner.py +++ b/api/controllers/console/explore/banner.py @@ -1,9 +1,11 @@ from flask import request from flask_restx import Resource +from sqlalchemy import select from controllers.console import api from controllers.console.explore.wraps import explore_banner_enabled from extensions.ext_database import db +from models.enums import BannerStatus from models.model import ExporleBanner @@ -16,14 +18,18 @@ class BannerApi(Resource): language = request.args.get("language", "en-US") # Build base query for enabled banners - base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") + base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) # Try to get banners in the requested language - banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort) + ).all() # Fallback to en-US if no banners found and language is not en-US if not banners and language != "en-US": - banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort) + ).all() # Convert banners to serializable format result = [] for banner in banners: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index aca766567f..0740dd0e24 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource): def post(self): payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1) + ) if recommended_app is None: raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == payload.app_id).first() + app = db.session.get(App, payload.app_id) if app is None: raise NotFound("App entity not found") @@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) - .first() + .limit(1) ) if installed_app is None: diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 53970dbd3b..15e1aea361 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource): app_model=app_model, message_id=message_id, user=current_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 25bb8ed7fe..a8d8036f0f 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -476,7 +477,7 @@ class TrialSitApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() @@ -541,13 +542,7 @@ class AppWorkflowApi(Resource): if not app_model.workflow_id: raise AppUnavailableError() - workflow = ( - db.session.query(Workflow) - .where( - Workflow.id == app_model.workflow_id, - ) - .first() - ) + workflow = db.session.get(Workflow, app_model.workflow_id) return workflow diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 03edb871e6..9d9337e63e 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import abort from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed @@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): _, current_tenant_id = current_account_with_tenant() - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id) - .first() + .limit(1) ) if installed_app is None: @@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): current_user, _ = current_account_with_tenant() - trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1)) if trial_app is None: raise TrialAppNotAllowed() @@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): if app is None: raise TrialAppNotAllowed() - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) - .first() + .limit(1) ) if account_trial_app_record: if account_trial_app_record.count >= trial_app.trial_limit: diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e099fe0f32..279e4ec502 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -2,6 +2,7 @@ from typing import Literal from flask import request from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from configs import dify_config from controllers.fastopenapi import console_router @@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": - return db.session.query(DifySetup).first() + return db.session.scalar(select(DifySetup).limit(1)) return True diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0d8960c9bd..6f93ff1e70 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -212,13 +212,13 @@ class AccountInitApi(Resource): raise ValueError("invitation_code is required") # check invitation code - invitation_code = ( - db.session.query(InvitationCode) + invitation_code = db.session.scalar( + select(InvitationCode) .where( InvitationCode.code == args.invitation_code, InvitationCode.status == InvitationCodeStatus.UNUSED, ) - .first() + .limit(1) ) if not invitation_code: diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd302b90d6..e3bf4c95b8 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - member = db.session.query(Account).where(Account.id == str(member_id)).first() + member = db.session.get(Account, str(member_id)) if member is None: abort(404) else: diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 94be81d94f..88fd2c010f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -7,6 +7,7 @@ from sqlalchemy import select from werkzeug.exceptions import Unauthorized import services +from configs import dify_config from controllers.common.errors import ( FilenameNotExistsError, FileTooLargeError, @@ -29,6 +30,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantStatus from services.account_service import TenantService +from services.billing_service import BillingService, SubscriptionPlan from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.file_service import FileService @@ -108,9 +110,29 @@ class TenantListApi(Resource): current_user, current_tenant_id = current_account_with_tenant() tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] + is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED + is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED + tenant_plans: dict[str, SubscriptionPlan] = {} + + if is_saas: + tenant_ids = [tenant.id for tenant in tenants] + if tenant_ids: + tenant_plans = BillingService.get_plan_bulk(tenant_ids) + if not tenant_plans: + logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path") for tenant in tenants: - features = FeatureService.get_features(tenant.id) + plan: str = CloudPlan.SANDBOX + if is_saas: + tenant_plan = tenant_plans.get(tenant.id) + if tenant_plan: + plan = tenant_plan["plan"] or CloudPlan.SANDBOX + else: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX + elif not is_enterprise_only: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX # Create a dictionary with tenant attributes tenant_dict = { @@ -118,7 +140,7 @@ class TenantListApi(Resource): "name": tenant.name, "status": tenant.status, "created_at": tenant.created_at, - "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX, + "plan": plan, "current": tenant.id == current_tenant_id if current_tenant_id else False, } @@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource): except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant + new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 014f4c4132..6785ba0c34 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,6 +7,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import abort, request +from sqlalchemy import select from configs import dify_config from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError @@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check setup - if ( - dify_config.EDITION == "SELF_HOSTED" - and os.environ.get("INIT_PASSWORD") - and not db.session.query(DifySetup).first() - ): - raise NotInitValidateError() - elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): + if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)): + if os.environ.get("INIT_PASSWORD"): + raise NotInitValidateError() raise NotSetupError() return view(*args, **kwargs) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 766d95b3dd..d6e3ebfbcd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.orm import Session from extensions.ext_database import db @@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = None if is_anonymous: - user_model = ( - session.query(EndUser) + user_model = session.scalar( + select(EndUser) .where( EndUser.session_id == user_id, EndUser.tenant_id == tenant_id, ) - .first() + .limit(1) ) else: - user_model = ( - session.query(EndUser) - .where( - EndUser.id == user_id, - EndUser.tenant_id == tenant_id, - ) - .first() - ) + user_model = session.get(EndUser, user_id) if not user_model: user_model = EndUser( @@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]): if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - try: - tenant_model = ( - db.session.query(Tenant) - .where( - Tenant.id == tenant_id, - ) - .first() - ) - except Exception: - raise ValueError("tenant not found") + tenant_model = db.session.get(Tenant, tenant_id) if not tenant_model: raise ValueError("tenant not found") diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index a5746abafa..ef0a46db63 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from controllers.common.schema import register_schema_models from controllers.console.wraps import setup_required @@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource): def post(self): args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {}) - account = db.session.query(Account).filter_by(email=args.owner_email).first() + account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1)) if account is None: return {"message": "owner account not found."}, 404 diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 4bdcc6832a..7c60b316e8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]): if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() + kwargs["user"] = db.session.get(EndUser, user_id) return view(*args, **kwargs) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 2aaf920efb..77fee9c142 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from libs.helper import UUIDStrOrEmpty +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab..25b6436a71 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -15,6 +15,7 @@ from controllers.service_api.wraps import ( cloud_edition_billing_rate_limit_check, ) from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag @@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore else: - item["embedding_available"] = False + item["embedding_available"] = False # type: ignore else: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore response = { "data": data, "has_more": len(datasets) == query.limit, @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..595b01a9f2 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,6 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields @@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/controllers/trigger/webhook.py b/api/controllers/trigger/webhook.py index 22b24271c6..eb579da5d4 100644 --- a/api/controllers/trigger/webhook.py +++ b/api/controllers/trigger/webhook.py @@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str): @bp.route("/webhook-debug/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) def handle_webhook_debug(webhook_id: str): - """Handle webhook debug calls without triggering production workflow execution.""" + """Handle webhook debug calls without triggering production workflow execution. + + The debug webhook endpoint is only for draft inspection flows. It never enqueues + Celery work for the published workflow; instead it dispatches an in-memory debug + event to an active Variable Inspector listener. Returning a clear error when no + listener is registered prevents a misleading 200 response for requests that are + effectively dropped. + """ try: webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True) if error: @@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str): "method": webhook_data.get("method"), }, ) - TriggerDebugEventBus.dispatch( + dispatch_count = TriggerDebugEventBus.dispatch( tenant_id=webhook_trigger.tenant_id, event=event, pool_key=pool_key, ) + if dispatch_count == 0: + logger.warning( + "Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)", + webhook_trigger.webhook_id, + webhook_trigger.tenant_id, + webhook_trigger.app_id, + webhook_trigger.node_id, + ) + return ( + jsonify( + { + "error": "No active debug listener", + "message": ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ), + "execution_url": webhook_trigger.webhook_url, + } + ), + 409, + ) response_data, status_code = WebhookService.generate_webhook_response(node_config) return jsonify(response_data), status_code diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 4e69e56025..36728a47d1 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -8,6 +8,7 @@ from datetime import datetime from flask import Response, request from flask_restx import Resource, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -147,11 +148,11 @@ class HumanInputFormApi(Resource): def _get_app_site_from_form(form: Form) -> tuple[App, Site]: """Resolve App/Site for the form's app and validate tenant status.""" - app_model = db.session.query(App).where(App.id == form.app_id).first() + app_model = db.session.get(App, form.app_id) if app_model is None or app_model.tenant_id != form.tenant_id: raise NotFoundError("Form not found") - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if site is None: raise Forbidden() diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 2b60691949..aa56292614 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper from libs.helper import uuid_value +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index f957229ece..1a0c6d4252 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,6 +1,7 @@ from typing import cast from flask_restx import fields, marshal, marshal_with +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource): def get(self, app_model, end_user): """Retrieve app site info.""" # get site - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 70f43b2c83..f04a8df119 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -8,6 +8,7 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService @@ -117,8 +118,10 @@ class DatasetConfigManager: score_threshold=float(score_threshold_val) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None else None, - reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None, - weights=weights_val if isinstance(weights_val, dict) else None, + reranking_model=cast(RerankingModelDict, reranking_model_val) + if isinstance(reranking_model_val, dict) + else None, + weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None, reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), metadata_filtering_mode=cast( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index ac21577d57..95ea70bc40 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.file import FileUploadConfig from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole @@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel): top_k: int | None = None score_threshold: float | None = 0.0 rerank_mode: str | None = "reranking_model" - reranking_model: dict | None = None - weights: dict | None = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None reranking_enabled: bool | None = True metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" metadata_model_config: ModelConfig | None = None diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6583ba51e9..f7b5030d33 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile -from models.enums import CreatorUserRole, MessageStatus +from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent from models.workflow import Workflow @@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, upload_file_id=file["related_id"], created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 77950a832a..a92e3dd2ea 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -74,11 +74,22 @@ class AppGenerateResponseConverter(ABC): for resource in metadata["retriever_resources"]: updated_resources.append( { + "dataset_id": resource.get("dataset_id"), + "dataset_name": resource.get("dataset_name"), + "document_id": resource.get("document_id"), "segment_id": resource.get("segment_id", ""), "position": resource["position"], + "data_source_type": resource.get("data_source_type"), "document_name": resource["document_name"], "score": resource["score"], + "hit_count": resource.get("hit_count"), + "word_count": resource.get("word_count"), + "segment_position": resource.get("segment_position"), + "index_node_hash": resource.get("index_node_hash"), "content": resource["content"], + "page": resource.get("page"), + "title": resource.get("title"), + "files": resource.get("files"), "summary": resource.get("summary"), } ) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 88714f3837..11fcbb7561 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import ( from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: @@ -419,7 +419,7 @@ class AppRunner: message_id=message_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=f"/files/tools/{tool_file.id}", upload_file_id=tool_file.id, created_by_role=( diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 5509764508..621b0d8cf3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -517,7 +517,7 @@ class WorkflowResponseConverter: snapshot = self._pop_snapshot(event.node_execution_id) start_at = snapshot.start_at if snapshot else event.start_at - finished_at = naive_utc_now() + finished_at = event.finished_at or naive_utc_now() elapsed_time = (finished_at - start_at).total_seconds() inputs, inputs_truncated = self._truncate_mapping(event.inputs) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 4e9a191dae..44d10d79b8 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): end_user_id = None account_id = None if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - from_source = "api" + from_source = ConversationFromSource.API end_user_id = application_generate_entity.user_id else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type, transfer_method=file.transfer_method, - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, upload_file_id=file.related_id, created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..bd6e2a0302 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 25d3c8bd2a..adc6cce9af 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -456,6 +456,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=inputs, process_data=process_data, outputs=outputs, @@ -471,6 +472,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, outputs=event.node_run_result.outputs, @@ -487,6 +489,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, outputs=event.node_run_result.outputs, diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 8899d80db8..d2a36f2a0d 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) @@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) @@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 50aed37163..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,9 +4,10 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset -from models.enums import CollectionBindingType +from models.enums import CollectionBindingType, ConversationFromSource from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, @@ -68,9 +69,9 @@ class AnnotationReplyFeature: annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: - from_source = "api" + from_source = ConversationFromSource.API else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE # insert annotation history AppAnnotationService.add_annotation_history( diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 2adaf14a35..a881fba877 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): """ """ + def __init__(self) -> None: + super().__init__() + self._paused = False + def on_graph_start(self): - pass + self._paused = False def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. """ if isinstance(event, GraphRunPausedEvent): - pass + self._paused = True def on_graph_end(self, error: Exception | None): """ """ - pass + self._paused = False + + def is_paused(self) -> bool: + return self._paused diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 536ab02eae..62f27060b4 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.enums import MessageFileBelongsTo from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -233,7 +234,7 @@ class MessageCycleManager: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or "user", + belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER, url=url, ) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index a30491f30c..d95a378575 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._handle_graph_run_paused(event) return - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - if isinstance(event, NodeRunRetryEvent): self._handle_node_retry(event) return + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + if isinstance(event, NodeRunSucceededEvent): self._handle_node_succeeded(event) return @@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: domain_execution = self._get_node_execution(event.id) - self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED) + self._update_node_execution( + domain_execution, + event.node_run_result, + WorkflowNodeExecutionStatus.SUCCEEDED, + finished_at=event.finished_at, + ) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): event.node_run_result, WorkflowNodeExecutionStatus.FAILED, error=event.error, + finished_at=event.finished_at, ) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: @@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): event.node_run_result, WorkflowNodeExecutionStatus.EXCEPTION, error=event.error, + finished_at=event.finished_at, ) def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None: @@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): *, error: str | None = None, update_outputs: bool = True, + finished_at: datetime | None = None, ) -> None: - finished_at = naive_utc_now() + actual_finished_at = finished_at or naive_utc_now() snapshot = self._node_snapshots.get(domain_execution.id) start_at = snapshot.created_at if snapshot else domain_execution.created_at domain_execution.status = status - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0) + domain_execution.finished_at = actual_finished_at + domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0) if error: domain_execution.error = error diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 5971c1e013..24243add17 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -15,6 +15,7 @@ from configs import dify_config from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import MessageFile, UploadFile from models.tools import ToolFile @@ -81,7 +82,7 @@ class DatasourceFileManager: upload_file = UploadFile( tenant_id=tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=filepath, name=present_filename, size=len(file_binary), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 0279725ff2..a9f2300ba2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db +from models.enums import CredentialSourceType from models.provider import ( LoadBalancingModelConfig, Provider, @@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) except Exception: @@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "provider", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: @@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="custom_model", + credential_source=CredentialSourceType.CUSTOM_MODEL, session=session, ) except Exception: @@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "custom_model", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() @@ -1421,12 +1422,12 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = s.execute(stmt).scalars().first() if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + preferred_model_provider.preferred_provider_type = provider_type else: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value, + preferred_provider_type=provider_type, ) s.add(preferred_model_provider) s.commit() @@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel): provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "custom_model" + if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL ] load_balancing_enabled = model_setting.load_balancing_enabled @@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel): custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "provider" + if config.credential_source_type != CredentialSourceType.PROVIDER ] load_balancing_enabled = model_setting.load_balancing_enabled diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..06bc366081 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -271,7 +271,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,7 +289,7 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, @@ -303,7 +303,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -573,7 +573,7 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -587,7 +587,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +597,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +628,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +654,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,7 +764,7 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 7cb54b2c88..f54461e99a 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance): arize_phoenix_config: ArizeConfig | PhoenixConfig, ): super().__init__(arize_phoenix_config) - import logging - - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config self.tracer, self.processor = setup_tracer(arize_phoenix_config) self.project = arize_phoenix_config.project diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index ef1a3be45b..ed6a7dabbb 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): if field_name == "inputs": data = { "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v ] if isinstance(v, list) else v, diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 0bbb62af93..ec4858ae2e 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - data={"plugin_unique_identifier": plugin_unique_identifier}, - headers={"Content-Type": "application/json"}, + params={"plugin_unique_identifier": plugin_unique_identifier}, ) def fetch_plugin_installation_by_ids( diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index ed34922346..6d2be0ab7a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -195,7 +195,7 @@ class ProviderManager: preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) if preferred_provider_type_record: - preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + preferred_provider_type = preferred_provider_type_record.preferred_provider_type elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: preferred_provider_type = ProviderType.SYSTEM elif custom_configuration.provider or custom_configuration.models: @@ -918,11 +918,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index e182c35b99..790253053d 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -1,9 +1,10 @@ import re +from typing import Any class CleanProcessor: @classmethod - def clean(cls, text: str, process_rule: dict) -> str: + def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str: # default clean # remove invalid symbol text = re.sub(r"<\|", "<", text) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2b73ef5f26..33eb5f963a 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from typing_extensions import TypedDict + from core.model_manager import ModelInstance, ModelManager from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +class RerankingModelDict(TypedDict): + reranking_provider_name: str + reranking_model_name: str + + +class VectorSettingDict(TypedDict): + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSettingDict(TypedDict): + keyword_weight: float + + +class WeightsDict(TypedDict): + vector_setting: VectorSettingDict + keyword_setting: KeywordSettingDict + + class DataPostProcessor: """Interface for data post-processing document.""" @@ -17,8 +39,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -45,8 +67,8 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( @@ -79,12 +101,14 @@ class DataPostProcessor: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: + def _get_rerank_model_instance( + self, tenant_id: str, reranking_model: RerankingModelDict | None + ) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() - reranking_provider_name = reranking_model.get("reranking_provider_name") - reranking_model_name = reranking_model.get("reranking_model_name") + reranking_provider_name = reranking_model["reranking_provider_name"] + reranking_model_name = reranking_model["reranking_model_name"] if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 0f19ecadc8..b07dc108be 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -4,6 +4,7 @@ from typing import Any import orjson from pydantic import BaseModel from sqlalchemy import select +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -15,6 +16,11 @@ from extensions.ext_storage import storage from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment +class PreSegmentData(TypedDict): + segment: DocumentSegment + keywords: list[str] + + class KeywordTableConfig(BaseModel): max_keywords_per_chunk: int = 10 @@ -128,7 +134,7 @@ class Jieba(BaseKeyword): file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) - def _save_dataset_keyword_table(self, keyword_table): + def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None): keyword_table_dict = { "__type__": "keyword_table", "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, @@ -144,7 +150,7 @@ class Jieba(BaseKeyword): storage.delete(file_key) storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) - def _get_dataset_keyword_table(self) -> dict | None: + def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict @@ -169,14 +175,16 @@ class Jieba(BaseKeyword): return {} - def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]): + def _add_text_to_keyword_table( + self, keyword_table: dict[str, set[str]], id: str, keywords: list[str] + ) -> dict[str, set[str]]: for keyword in keywords: if keyword not in keyword_table: keyword_table[keyword] = set() keyword_table[keyword].add(id) return keyword_table - def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]): + def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]: # get set of ids that correspond to node node_idxs_to_delete = set(ids) @@ -193,7 +201,7 @@ class Jieba(BaseKeyword): return keyword_table - def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): + def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]: keyword_table_handler = JiebaKeywordTableHandler() keywords = keyword_table_handler.extract_keywords(query) @@ -228,7 +236,7 @@ class Jieba(BaseKeyword): keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) - def multi_create_segment_keywords(self, pre_segment_data_list: list): + def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index e8a3a05e19..713319ab9d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,19 +1,20 @@ import concurrent.futures import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, NotRequired from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only +from typing_extensions import TypedDict from configs import dify_config from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments +from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -35,7 +36,49 @@ from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { + +class SegmentAttachmentResult(TypedDict): + attachment_info: AttachmentInfoDict + segment_id: str + + +class SegmentAttachmentInfoResult(TypedDict): + attachment_id: str + attachment_info: AttachmentInfoDict + segment_id: str + + +class ChildChunkDetail(TypedDict): + id: str + content: str + position: int + score: float + + +class SegmentChildMapDetail(TypedDict): + max_score: float + child_chunks: list[ChildChunkDetail] + + +class SegmentRecord(TypedDict): + segment: DocumentSegment + score: NotRequired[float] + child_chunks: NotRequired[list[ChildChunkDetail]] + files: NotRequired[list[AttachmentInfoDict]] + + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -56,11 +99,11 @@ class RetrievalService: query: str, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, - attachment_ids: list | None = None, + attachment_ids: list[str] | None = None, ): if not query and not attachment_ids: return [] @@ -207,8 +250,8 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - all_documents: list, - exceptions: list, + all_documents: list[Document], + exceptions: list[str], document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): @@ -235,10 +278,10 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, - all_documents: list, + reranking_model: RerankingModelDict | None, + all_documents: list[Document], retrieval_method: RetrievalMethod, - exceptions: list, + exceptions: list[str], document_ids_filter: list[str] | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ): @@ -277,8 +320,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( @@ -288,8 +331,8 @@ class RetrievalService: model_manager = ModelManager() is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, - provider=reranking_model.get("reranking_provider_name") or "", - model=reranking_model.get("reranking_model_name") or "", + provider=reranking_model["reranking_provider_name"], + model=reranking_model["reranking_model_name"], model_type=ModelType.RERANK, ) if is_support_vision: @@ -329,10 +372,10 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, - all_documents: list, + reranking_model: RerankingModelDict | None, + all_documents: list[Document], retrieval_method: str, - exceptions: list, + exceptions: list[str], document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): @@ -349,8 +392,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( @@ -459,7 +502,7 @@ class RetrievalService: segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map: dict[str, list[dict[str, Any]]] = {} + attachment_map: dict[str, list[AttachmentInfoDict]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} segment_summary_map: dict[str, str] = {} # Map segment_id to summary content @@ -544,12 +587,12 @@ class RetrievalService: segment_summary_map[summary.chunk_id] = summary.summary_content include_segment_ids = set() - segment_child_map: dict[str, dict[str, Any]] = {} - records: list[dict[str, Any]] = [] + segment_child_map: dict[str, SegmentChildMapDetail] = {} + records: list[SegmentRecord] = [] for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) - attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, []) ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: @@ -560,14 +603,14 @@ class RetrievalService: max_score = summary_score_map.get(segment.id, 0.0) if child_chunks or attachment_infos: - child_chunk_details = [] + child_chunk_details: list[ChildChunkDetail] = [] for child_chunk in child_chunks: child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) if child_document: child_score = child_document.metadata.get("score", 0.0) else: child_score = 0.0 - child_chunk_detail = { + child_chunk_detail: ChildChunkDetail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, @@ -580,7 +623,7 @@ class RetrievalService: if file_document: max_score = max(max_score, file_document.metadata.get("score", 0.0)) - map_detail = { + map_detail: SegmentChildMapDetail = { "max_score": max_score, "child_chunks": child_chunk_details, } @@ -593,7 +636,7 @@ class RetrievalService: "max_score": summary_score, "child_chunks": [], } - record: dict[str, Any] = { + record: SegmentRecord = { "segment": segment, } records.append(record) @@ -617,19 +660,19 @@ class RetrievalService: if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) - record = { + another_record: SegmentRecord = { "segment": segment, "score": max_score, } - records.append(record) + records.append(another_record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: - record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore - record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore + record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"] + record["score"] = segment_child_map[record["segment"].id]["max_score"] if record["segment"].id in attachment_map: - record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] + record["files"] = attachment_map[record["segment"].id] result: list[RetrievalSegments] = [] for record in records: @@ -693,9 +736,9 @@ class RetrievalService: query: str | None = None, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_id: str | None = None, ): @@ -807,7 +850,7 @@ class RetrievalService: @classmethod def get_segment_attachment_info( cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session - ) -> dict[str, Any] | None: + ) -> SegmentAttachmentResult | None: upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() if upload_file: attachment_binding = ( @@ -816,7 +859,7 @@ class RetrievalService: .first() ) if attachment_binding: - attachment_info = { + attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -828,8 +871,10 @@ class RetrievalService: return None @classmethod - def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]: - attachment_infos = [] + def get_segment_attachment_infos( + cls, attachment_ids: list[str], session: Session + ) -> list[SegmentAttachmentInfoResult]: + attachment_infos: list[SegmentAttachmentInfoResult] = [] upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -843,7 +888,7 @@ class RetrievalService: if attachment_bindings: for upload_file in upload_files: attachment_binding = attachment_binding_map.get(upload_file.id) - attachment_info = { + info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -855,7 +900,7 @@ class RetrievalService: attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, - "attachment_info": attachment_info, + "attachment_info": info, "segment_id": attachment_binding.segment_id, } ) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d834495..9f5842e449 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore from pymochow.model.schema import ( + AutoBuildRowCountIncrement, Field, FilteringIndex, HNSWParams, @@ -51,6 +52,9 @@ class BaiduConfig(BaseModel): replicas: int = 3 inverted_index_analyzer: str = "DEFAULT_ANALYZER" inverted_index_parser_mode: str = "COARSE_MODE" + auto_build_row_count_increment: int = 500 + auto_build_row_count_increment_ratio: float = 0.05 + rebuild_index_timeout_in_seconds: int = 300 @model_validator(mode="before") @classmethod @@ -107,18 +111,6 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) - # rebuild vector index after upsert finished - table.rebuild_index(self.vector_index) - timeout = 3600 # 1 hour timeout - start_time = time.time() - while True: - time.sleep(1) - index = table.describe_index(self.vector_index) - if index.state == IndexState.NORMAL: - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") - def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: @@ -232,8 +224,14 @@ class BaiduVector(BaseVector): return self._client.database(self._client_config.database) def _table_existed(self) -> bool: - tables = self._db.list_table() - return any(table.table_name == self._collection_name for table in tables) + try: + table = self._db.table(self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + return False + else: + raise + return True def _create_table(self, dimension: int): # Try to grab distributed lock and create table @@ -287,6 +285,11 @@ class BaiduVector(BaseVector): field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement( + row_count_increment=self._client_config.auto_build_row_count_increment, + row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio, + ), ) ) @@ -335,7 +338,7 @@ class BaiduVector(BaseVector): ) # Wait for table created - timeout = 300 # 5 minutes timeout + timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout start_time = time.time() while True: time.sleep(1) @@ -345,6 +348,20 @@ class BaiduVector(BaseVector): if time.time() - start_time > timeout: raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) + # rebuild vector index immediately after table created, make sure index is ready + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + self._wait_for_index_ready(table, timeout) + + def _wait_for_index_ready(self, table, timeout: int = 3600): + start_time = time.time() + while True: + time.sleep(1) + index = table.describe_index(self.vector_index) + if index.state == IndexState.NORMAL: + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") class BaiduVectorFactory(AbstractVectorFactory): @@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory): replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, + auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT, + auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO, + rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS, ), ) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f..df02c584ed 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector): ) ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) docs = [] for doc, score in docs_and_scores: - score_threshold = float(kwargs.get("score_threshold") or 0.0) if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 56ffb36a2b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -284,27 +285,29 @@ class TidbOnQdrantVector(BaseVector): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse - for node_id in ids: - try: - filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - ) - self._client.delete( - collection_name=self._collection_name, - points_selector=FilterSelector(filter=filter), - ) - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - return - # Some other error occurred, so re-raise the exception - else: - raise e + if not ids: + return + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=ids), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e def text_exists(self, id: str) -> bool: all_collection_name = [] @@ -450,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 5ab03a1380..d29d62c93f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r document embeddings used in retrieval-augmented generation workflows. """ +import atexit import datetime import json import logging @@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None _weaviate_client_lock = threading.Lock() +def _shutdown_weaviate_client() -> None: + """ + Best-effort shutdown hook to close the module-level Weaviate client. + + This is registered with atexit so that HTTP/gRPC resources are released + when the Python interpreter exits. + """ + global _weaviate_client + + # Ensure thread-safety when accessing the shared client instance + with _weaviate_client_lock: + client = _weaviate_client + _weaviate_client = None + + if client is not None: + try: + client.close() + except Exception: + # Best-effort cleanup; log at debug level and ignore errors. + logger.debug("Failed to close Weaviate client during shutdown", exc_info=True) + + +# Register the shutdown hook once per process. +atexit.register(_shutdown_weaviate_client) + + class WeaviateConfig(BaseModel): """ Configuration model for Weaviate connection settings. @@ -85,18 +112,6 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes - def __del__(self): - """ - Destructor to properly close the Weaviate client connection. - Prevents connection leaks and resource warnings. - """ - if hasattr(self, "_client") and self._client is not None: - try: - self._client.close() - except Exception as e: - # Ignore errors during cleanup as object is being destroyed - logger.warning("Error closing Weaviate client %s", e, exc_info=True) - def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..cd27113245 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,6 +6,7 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -71,7 +72,7 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index f6834ab87b..030237559d 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,8 +1,18 @@ from pydantic import BaseModel +from typing_extensions import TypedDict from models.dataset import DocumentSegment +class AttachmentInfoDict(TypedDict): + id: str + name: str + extension: str + mime_type: str + source_url: str + size: int + + class RetrievalChildChunk(BaseModel): """Retrieval segments.""" @@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None - files: list[dict[str, str | int]] | None = None + files: list[AttachmentInfoDict] | None = None summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 5d6223db06..e1ddd2dd96 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,12 +1,38 @@ import json import time -from typing import Any, cast +from typing import Any, NotRequired, cast import httpx +from typing_extensions import TypedDict from extensions.ext_storage import storage +class FirecrawlDocumentData(TypedDict): + title: str | None + description: str | None + source_url: str | None + markdown: str | None + + +class CrawlStatusResponse(TypedDict): + status: str + total: int | None + current: int | None + data: list[FirecrawlDocumentData] + + +class MapResponse(TypedDict): + success: bool + links: list[str] + + +class SearchResponse(TypedDict): + success: bool + data: list[dict[str, Any]] + warning: NotRequired[str] + + class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key @@ -14,7 +40,7 @@ class FirecrawlApp: if self.api_key is None and self.base_url == "https://api.firecrawl.dev": raise ValueError("No API key provided") - def scrape_url(self, url, params=None) -> dict[str, Any]: + def scrape_url(self, url, params=None) -> FirecrawlDocumentData: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape headers = self._prepare_headers() json_data = { @@ -32,9 +58,7 @@ class FirecrawlApp: return self._extract_common_fields(data) elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "scrape URL") - return {} # Avoid additional exception after handling error - else: - raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post @@ -51,7 +75,7 @@ class FirecrawlApp: self._handle_error(response, "start crawl job") return "" # unreachable - def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map headers = self._prepare_headers() json_data: dict[str, Any] = {"url": url, "integration": "dify"} @@ -60,28 +84,22 @@ class FirecrawlApp: json_data.update(params) response = self._post_request(self._build_url("v2/map"), json_data, headers) if response.status_code == 200: - return cast(dict[str, Any], response.json()) + return cast(MapResponse, response.json()) elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "start map job") - return {} - else: - raise Exception(f"Failed to start map job. Status code: {response.status_code}") + raise Exception(f"Failed to start map job. Status code: {response.status_code}") - def check_crawl_status(self, job_id) -> dict[str, Any]: + def check_crawl_status(self, job_id) -> CrawlStatusResponse: headers = self._prepare_headers() response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers) if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": - total = crawl_status_response.get("total", 0) - if total == 0: + # Normalize to avoid None bypassing the zero-guard when the API returns null. + total = crawl_status_response.get("total") or 0 + if total <= 0: raise Exception("Failed to check crawl status. Error: No page found") - data = crawl_status_response.get("data", []) - url_data_list = [] - for item in data: - if isinstance(item, dict) and "metadata" in item and "markdown" in item: - url_data = self._extract_common_fields(item) - url_data_list.append(url_data) + url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers) if url_data_list: file_key = "website_files/" + job_id + ".txt" try: @@ -95,13 +113,45 @@ class FirecrawlApp: return self._format_crawl_status_response( crawl_status_response.get("status"), crawl_status_response, [] ) - else: - self._handle_error(response, "check crawl status") - return {} # unreachable + self._handle_error(response, "check crawl status") + raise RuntimeError("unreachable: _handle_error always raises") + + def _collect_all_crawl_pages( + self, first_page: dict[str, Any], headers: dict[str, str] + ) -> list[FirecrawlDocumentData]: + """Collect all crawl result pages by following pagination links. + + Raises an exception if any paginated request fails, to avoid returning + partial data that is inconsistent with the reported total. + + The number of pages processed is capped at ``total`` (the + server-reported page count) to guard against infinite loops caused by + a misbehaving server that keeps returning a ``next`` URL. + """ + total: int = first_page.get("total") or 0 + url_data_list: list[FirecrawlDocumentData] = [] + current_page = first_page + pages_processed = 0 + while True: + for item in current_page.get("data", []): + if isinstance(item, dict) and "metadata" in item and "markdown" in item: + url_data_list.append(self._extract_common_fields(item)) + next_url: str | None = current_page.get("next") + pages_processed += 1 + if not next_url or pages_processed >= total: + break + response = self._get_request(next_url, headers) + if response.status_code != 200: + self._handle_error(response, "fetch next crawl page") + current_page = response.json() + return url_data_list def _format_crawl_status_response( - self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]] - ) -> dict[str, Any]: + self, + status: str, + crawl_status_response: dict[str, Any], + url_data_list: list[FirecrawlDocumentData], + ) -> CrawlStatusResponse: return { "status": status, "total": crawl_status_response.get("total"), @@ -109,7 +159,7 @@ class FirecrawlApp: "data": url_data_list, } - def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]: + def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData: return { "title": item.get("metadata", {}).get("title"), "description": item.get("metadata", {}).get("description"), @@ -117,7 +167,7 @@ class FirecrawlApp: "markdown": item.get("markdown"), } - def _prepare_headers(self) -> dict[str, Any]: + def _prepare_headers(self) -> dict[str, str]: return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _build_url(self, path: str) -> str: @@ -150,10 +200,10 @@ class FirecrawlApp: error_message = response.text or "Unknown error occurred" raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return] - def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search headers = self._prepare_headers() - json_data = { + json_data: dict[str, Any] = { "query": query, "limit": 5, "lang": "en", @@ -170,12 +220,10 @@ class FirecrawlApp: json_data.update(params) response = self._post_request(self._build_url("v2/search"), json_data, headers) if response.status_code == 200: - response_data = response.json() + response_data: SearchResponse = response.json() if not response_data.get("success"): raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}") - return cast(dict[str, Any], response_data) + return response_data elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "perform search") - return {} # Avoid additional exception after handling error - else: - raise Exception(f"Failed to perform search. Status code: {response.status_code}") + raise Exception(f"Failed to perform search. Status code: {response.status_code}") diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 6aabcac704..9abdb31325 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self._tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=len(img_bytes), diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 7cf6c4d289..e8da866870 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -1,10 +1,11 @@ import json from collections.abc import Generator -from typing import Union +from typing import Any, Union from urllib.parse import urljoin import httpx from httpx import Response +from typing_extensions import TypedDict from core.rag.extractor.watercrawl.exceptions import ( WaterCrawlAuthenticationError, @@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import ( ) +class SpiderOptions(TypedDict): + max_depth: int + page_limit: int + allowed_domains: list[str] + exclude_paths: list[str] + include_paths: list[str] + + +class PageOptions(TypedDict): + exclude_tags: list[str] + include_tags: list[str] + wait_time: int + include_html: bool + only_main_content: bool + include_links: bool + timeout: int + accept_cookies_selector: str + locale: str + actions: list[Any] + + class BaseAPIClient: def __init__(self, api_key, base_url): self.api_key = api_key @@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient): def create_crawl_request( self, url: Union[list, str] | None = None, - spider_options: dict | None = None, - page_options: dict | None = None, - plugin_options: dict | None = None, + spider_options: SpiderOptions | None = None, + page_options: PageOptions | None = None, + plugin_options: dict[str, Any] | None = None, ): data = { # 'urls': url if isinstance(url, list) else [url], @@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient): def scrape_url( self, url: str, - page_options: dict | None = None, - plugin_options: dict | None = None, + page_options: PageOptions | None = None, + plugin_options: dict[str, Any] | None = None, sync: bool = True, prefetched: bool = True, ): diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index fe983aa86a..81c19005db 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -2,16 +2,39 @@ from collections.abc import Generator from datetime import datetime from typing import Any -from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient +from typing_extensions import TypedDict + +from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient + + +class WatercrawlDocumentData(TypedDict): + title: str | None + description: str | None + source_url: str | None + markdown: str | None + + +class CrawlJobResponse(TypedDict): + status: str + job_id: str | None + + +class WatercrawlCrawlStatusResponse(TypedDict): + status: str + job_id: str | None + total: int + current: int + data: list[WatercrawlDocumentData] + time_consuming: float class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: dict | Any | None = None): + def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse: options = options or {} - spider_options = { + spider_options: SpiderOptions = { "max_depth": 1, "page_limit": 1, "allowed_domains": [], @@ -25,7 +48,7 @@ class WaterCrawlProvider: spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else [] wait_time = options.get("wait_time", 1000) - page_options = { + page_options: PageOptions = { "exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [], "include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [], "wait_time": max(1000, wait_time), # minimum wait time is 1 second @@ -41,9 +64,9 @@ class WaterCrawlProvider: return {"status": "active", "job_id": result.get("uuid")} - def get_crawl_status(self, crawl_request_id): + def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse: response = self.client.get_crawl_request(crawl_request_id) - data = [] + data: list[WatercrawlDocumentData] = [] if response["status"] in ["new", "running"]: status = "active" else: @@ -67,7 +90,7 @@ class WaterCrawlProvider: "time_consuming": time_consuming, } - def get_crawl_url_data(self, job_id, url) -> dict | None: + def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None: if not job_id: return self.scrape_url(url) @@ -82,11 +105,11 @@ class WaterCrawlProvider: return None - def scrape_url(self, url: str): + def scrape_url(self, url: str) -> WatercrawlDocumentData: response = self.client.scrape_url(url=url, sync=True, prefetched=True) return self._structure_data(response) - def _structure_data(self, result_object: dict): + def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData: if isinstance(result_object.get("result", {}), str): raise ValueError("Invalid result object. Expected a dictionary.") @@ -98,7 +121,9 @@ class WaterCrawlProvider: "markdown": result_object.get("result", {}).get("markdown"), } - def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]: + def _get_results( + self, crawl_request_id: str, query_params: dict | None = None + ) -> Generator[WatercrawlDocumentData, None, None]: page = 0 page_size = 100 diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index d6b6ca35be..052fca930d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=0, @@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=0, @@ -365,7 +366,7 @@ class WordExtractor(BaseExtractor): paragraph_content = [] # State for legacy HYPERLINK fields hyperlink_field_url = None - hyperlink_field_text_parts: list = [] + hyperlink_field_text_parts: list[str] = [] is_collecting_field_text = False # Iterate through paragraph elements in document order for child in paragraph._element: diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index a7c42c5a4e..a6d1db214b 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,8 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment @@ -51,7 +53,7 @@ class IndexProcessor: original_document_id: str, chunks: Mapping[str, Any], batch: Any, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): with session_factory.create_session() as session: document = session.query(Document).filter_by(id=document_id).first() @@ -131,7 +133,12 @@ class IndexProcessor: } def get_preview_output( - self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: SummaryIndexSettingDict | None, ) -> Preview: doc_language = None with session_factory.create_session() as session: @@ -153,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index f2191f3702..a435dfc46a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -7,14 +7,16 @@ import os import re from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from typing_extensions import TypedDict from configs import dify_config from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import AttachmentDocument, Document @@ -35,6 +37,13 @@ if TYPE_CHECKING: from core.model_manager import ModelInstance +class SummaryIndexSettingDict(TypedDict): + enable: bool + model_name: NotRequired[str] + model_provider_name: NotRequired[str] + summary_prompt: NotRequired[str] + + class BaseIndexProcessor(ABC): """Interface for extract files.""" @@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9c21dad488..726cc062f6 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector @@ -21,8 +22,8 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -116,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -154,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -252,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) @@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def generate_summary( tenant_id: str, text: str, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, segment_id: str | None = None, document_language: str | None = None, ) -> tuple[str, LLMUsage]: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 367f0aec00..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -11,14 +11,15 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -127,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -165,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -331,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: @@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 503cce2132..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,13 +15,14 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -140,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ): # Set search parameters. results = RetrievalService.retrieve( @@ -223,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: @@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 15486e1fb8..52061fd93d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -31,9 +31,9 @@ from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -591,7 +591,7 @@ class DatasetRetrieval: user_id: str, user_from: str, query: str, - available_datasets: list, + available_datasets: list[Dataset], model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -633,15 +633,15 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) - dataset = db.session.scalar(dataset_stmt) - if dataset: + selected_dataset = db.session.scalar(dataset_stmt) + if selected_dataset: results = [] - if dataset.provider == "external": + if selected_dataset.provider == "external": external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=dataset.tenant_id, + tenant_id=selected_dataset.tenant_id, dataset_id=dataset_id, query=query, - external_retrieval_parameters=dataset.retrieval_model, + external_retrieval_parameters=selected_dataset.retrieval_model, metadata_condition=metadata_condition, ) for external_document in external_documents: @@ -654,24 +654,28 @@ class DatasetRetrieval: document.metadata["score"] = external_document.get("score") document.metadata["title"] = external_document.get("title") document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + document.metadata["dataset_name"] = selected_dataset.name results.append(document) else: if metadata_condition and not metadata_filter_document_ids: return [] document_ids_filter = None if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) + document_ids = metadata_filter_document_ids.get(selected_dataset.id, []) if document_ids: document_ids_filter = document_ids else: return [] - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model) + if selected_dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -690,7 +694,7 @@ class DatasetRetrieval: with measure_time() as timer: results = RetrievalService.retrieve( retrieval_method=retrieval_method, - dataset_id=dataset.id, + dataset_id=selected_dataset.id, query=query, top_k=top_k, score_threshold=score_threshold, @@ -722,13 +726,13 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, + available_datasets: list[Dataset], query: str | None, top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict[str, Any] | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reranking_enable: bool = True, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, @@ -748,7 +752,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -1024,7 +1028,7 @@ class DatasetRetrieval: dataset_id: str, query: str, top_k: int, - all_documents: list, + all_documents: list[Document], document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, attachment_ids: list[str] | None = None, @@ -1058,9 +1062,13 @@ class DatasetRetrieval: all_documents.append(document) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -1132,7 +1140,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config - default_retrieval_model = { + default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -1141,7 +1149,11 @@ class DatasetRetrieval: } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"] @@ -1181,8 +1193,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"], + reranking_model_name=retrieve_config.reranking_model["reranking_model_name"], ) tools.append(tool) @@ -1286,7 +1298,7 @@ class DatasetRetrieval: def get_metadata_filter_condition( self, - dataset_ids: list, + dataset_ids: list[str], query: str, tenant_id: str, user_id: str, @@ -1388,7 +1400,7 @@ class DatasetRetrieval: return output def _automatic_metadata_filter_func( - self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig + self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> list[dict[str, Any]] | None: # get all metadata field metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) @@ -1586,7 +1598,7 @@ class DatasetRetrieval: ) def _get_prompt_template( - self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str + self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str ): model_mode = ModelMode(mode) input_text = query @@ -1678,15 +1690,15 @@ class DatasetRetrieval: def _multiple_retrieve_thread( self, flask_app: Flask, - available_datasets: list, + available_datasets: list[Dataset], metadata_condition: MetadataCondition | None, metadata_filter_document_ids: dict[str, list[str]] | None, all_documents: list[Document], tenant_id: str, reranking_enable: bool, reranking_mode: str, - reranking_model: dict | None, - weights: dict[str, Any] | None, + reranking_model: RerankingModelDict | None, + weights: WeightsDict | None, top_k: int, score_threshold: float, query: str | None, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 79d7821b4e..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,8 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService from tasks.generate_summary_index_task import generate_summary_index_task @@ -11,12 +13,16 @@ logger = logging.getLogger(__name__) class SummaryIndex: def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + self, + dataset_id: str, + document_id: str, + is_preview: bool, + summary_index_setting: SummaryIndexSettingDict | None = None, ) -> None: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 00f5931088..bcf58394ba 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -50,7 +50,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id or "", - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, ) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0f0eacbdc4..64212a2636 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool from dify_graph.file import FileType from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile logger = logging.getLogger(__name__) @@ -352,7 +352,7 @@ class ToolEngine: message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=message.url, upload_file_id=tool_file_id, created_by_role=( diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 90d5a647e9..250dd91bfd 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -38,7 +38,7 @@ class ToolLabelManager: db.session.add( ToolLabelBinding( tool_id=provider_id, - tool_type=controller.provider_type.value, + tool_type=controller.provider_type, label_name=label, ) ) @@ -58,7 +58,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") stmt = select(ToolLabelBinding.label_name).where( ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, + ToolLabelBinding.tool_type == controller.provider_type, ) labels = db.session.scalars(stmt).all() diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..75b923fd8b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 2969fafe89..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -4,9 +4,11 @@ from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -20,9 +22,9 @@ from services.external_knowledge_service import ExternalDatasetService class DefaultRetrievalModelDict(TypedDict): search_method: RetrievalMethod reranking_enable: bool - reranking_model: dict[str, str] + reranking_model: RerankingModelDict reranking_mode: NotRequired[str] - weights: NotRequired[dict[str, object] | None] + weights: NotRequired[WeightsDict | None] score_threshold: NotRequired[float] top_k: int score_threshold_enabled: bool @@ -139,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -172,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8f958563bd..373bd1b1c8 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -9,6 +9,7 @@ from decimal import Decimal from typing import cast from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType from dify_graph.model_runtime.entities.llm_entities import LLMResult from dify_graph.model_runtime.entities.message_entities import PromptMessage from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -78,7 +79,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context diff --git a/api/core/trigger/constants.py b/api/core/trigger/constants.py index bfa45c3f2b..192faa2d3e 100644 --- a/api/core/trigger/constants.py +++ b/api/core/trigger/constants.py @@ -3,7 +3,6 @@ from typing import Final TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook" TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule" TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin" -TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info" TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset( { diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8b00746268..8d2e9bf3cb 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -2,6 +2,7 @@ from typing import Literal, Union from pydantic import BaseModel +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData @@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData): chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None - summary_index_setting: dict | None = None + summary_index_setting: SummaryIndexSettingDict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 0a74847bc1..4ea9091c5b 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any from core.rag.index_processor.index_processor import IndexProcessor +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.graph_config import NodeConfigDict @@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): is_preview: bool, batch: Any, chunks: Mapping[str, Any], - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): if not document_id: raise KnowledgeIndexNodeError("document_id is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 9c3b9aacbf..80f59140be 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict @@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - reranking_model = None - weights = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None match node_data.multiple_retrieval_config.reranking_mode: case "reranking_model": if node_data.multiple_retrieval_config.reranking_model: diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index f964f79582..e1311ab962 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.model_runtime.entities import LLMUsage from dify_graph.nodes.llm.entities import ModelConfig @@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel): top_k: int = Field(default=0, description="Number of top results to return") score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") reranking_mode: str = Field(default="reranking_model", description="Reranking strategy") - reranking_model: dict | None = Field(default=None, description="Reranking model configuration") - weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking") + reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration") + weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking") reranking_enable: bool = Field(default=True, description="Whether reranking is enabled") attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval") diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 2048a53064..118c2f2668 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping -from typing import Any, cast +from typing import Any -from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey @@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]): # Get trigger data passed when workflow was triggered metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { - cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): { + WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { "provider_id": self.node_data.provider_id, "event_name": self.node_data.event_name, "plugin_unique_identifier": self.node_data.plugin_unique_identifier, diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index 06653bebb6..cfb135cbb0 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -245,6 +245,9 @@ _END_STATE = frozenset( class WorkflowNodeExecutionMetadataKey(StrEnum): """ Node Run Metadata Key. + + Values in this enum are persisted as execution metadata and must stay in sync + with every node that writes `NodeRunResult.metadata`. """ TOTAL_TOKENS = "total_tokens" @@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output DATASOURCE_INFO = "datasource_info" + TRIGGER_INFO = "trigger_info" COMPLETED_REASON = "completed_reason" # completed reason for loop node diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py index d4ee2922ec..e206f21592 100644 --- a/api/dify_graph/graph_engine/error_handler.py +++ b/api/dify_graph/graph_engine/error_handler.py @@ -159,6 +159,7 @@ class ErrorHandler: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.EXCEPTION, inputs=event.node_run_result.inputs, @@ -198,6 +199,7 @@ class ErrorHandler: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.EXCEPTION, inputs=event.node_run_result.inputs, diff --git a/api/dify_graph/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py index 5c5d0fe5b9..988c20d72a 100644 --- a/api/dify_graph/graph_engine/worker.py +++ b/api/dify_graph/graph_engine/worker.py @@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final from typing_extensions import override from dify_graph.context import IExecutionContext +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event +from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event +from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node +from libs.datetime_utils import naive_utc_now from .ready_queue import ReadyQueue @@ -65,6 +68,7 @@ class Worker(threading.Thread): self._stop_event = threading.Event() self._layers = layers if layers is not None else [] self._last_task_time = time.time() + self._current_node_started_at: datetime | None = None def stop(self) -> None: """Signal the worker to stop processing.""" @@ -104,18 +108,15 @@ class Worker(threading.Thread): self._last_task_time = time.time() node = self._graph.nodes[node_id] try: + self._current_node_started_at = None self._execute_node(node) self._ready_queue.task_done() except Exception as e: - error_event = NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=str(e), - start_at=datetime.now(), + self._event_queue.put( + self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) ) - self._event_queue.put(error_event) + finally: + self._current_node_started_at = None def _execute_node(self, node: Node) -> None: """ @@ -136,6 +137,8 @@ class Worker(threading.Thread): try: node_events = node.run() for event in node_events: + if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: + self._current_node_started_at = event.start_at self._event_queue.put(event) if is_node_result_event(event): result_event = event @@ -149,6 +152,8 @@ class Worker(threading.Thread): try: node_events = node.run() for event in node_events: + if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: + self._current_node_started_at = event.start_at self._event_queue.put(event) if is_node_result_event(event): result_event = event @@ -177,3 +182,24 @@ class Worker(threading.Thread): except Exception: # Silently ignore layer errors to prevent disrupting node execution continue + + def _build_fallback_failure_event( + self, node: Node, error: Exception, *, started_at: datetime | None = None + ) -> NodeRunFailedEvent: + """Build a failed event when worker-level execution aborts before a node emits its own result event.""" + failure_time = naive_utc_now() + error_message = str(error) + return NodeRunFailedEvent( + id=node.execution_id, + node_id=node.id, + node_type=node.node_type, + in_iteration_id=None, + error=error_message, + start_at=started_at or failure_time, + finished_at=failure_time, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + error_type=type(error).__name__, + ), + ) diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py index 8552254627..df19d6c03b 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase): class NodeRunSucceededEvent(GraphNodeEventBase): start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunFailedEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunExceptionEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunRetryEvent(NodeRunStartedEvent): diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index c6f54ce672..56b46a5894 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]): error=str(e), error_type="WorkflowNodeError", ) + finished_at = naive_utc_now() yield NodeRunFailedEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, error=str(e), ) @@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]): return self._node_data def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: + finished_at = naive_utc_now() match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( @@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]): node_id=self.id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, error=result.error, ) @@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]): node_id=self.id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, ) case _: @@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: + finished_at = naive_utc_now() match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]): node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=event.node_run_result, ) case WorkflowNodeExecutionStatus.FAILED: @@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]): node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=event.node_run_result, error=event.node_run_result.error, ) diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 486ae241ee..3e5253d809 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -101,6 +101,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, http_request_config=self._http_request_config, + # Must be 0 to disable executor-level retries, as the graph engine handles them. + # This is critical to prevent nested retries. + max_retries=0, ssl_verify=self.node_data.ssl_verify, http_client=self._http_client, file_manager=self._file_manager, diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index f63ba0bc48..033ec8672f 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): future_to_index: dict[ Future[ tuple[ - datetime, + float, list[GraphNodeEventBase], object | None, dict[str, Variable], @@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): try: result = future.result() ( - iter_start_at, + iteration_duration, events, output_value, conversation_snapshot, @@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # Yield all events from this iteration yield from events - # Update tokens and timing - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # The worker computes duration before we replay buffered events here, + # so slow downstream consumers don't inflate per-iteration timing. + iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) @@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): index: int, item: object, execution_context: "IExecutionContext", - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with execution_context: iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): conversation_snapshot = self._extract_conversation_variable_snapshot( variable_pool=graph_engine.graph_runtime_state.variable_pool ) + iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() return ( - iter_start_at, + iteration_duration, events, output_value, conversation_snapshot, diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 073dce232f..8682c3682c 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -1,6 +1,9 @@ from __future__ import annotations -from collections.abc import Sequence +import json +import logging +import re +from collections.abc import Mapping, Sequence from typing import Any, cast from core.model_manager import ModelInstance @@ -36,6 +39,11 @@ from .exc import ( ) from .protocols import TemplateRenderer +logger = logging.getLogger(__name__) + +VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") +MAX_RESOLVED_VALUE_LENGTH = 1024 + def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( @@ -256,9 +264,13 @@ def fetch_prompt_messages( ): continue prompt_message_content.append(content_item) - if prompt_message_content: + if not prompt_message_content: + continue + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) + filtered_prompt_messages.append(prompt_message) elif not prompt_message.is_empty(): filtered_prompt_messages.append(prompt_message) @@ -471,3 +483,61 @@ def _append_file_prompts( prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + + +def _coerce_resolved_value(raw: str) -> int | float | bool | str: + """Try to restore the original type from a resolved template string. + + Variable references are always resolved to text, but completion params may + expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to + the ``temperature`` parameter). This helper attempts a JSON parse so that + ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not + valid JSON literals are returned as-is. + """ + stripped = raw.strip() + if not stripped: + return raw + + try: + parsed: object = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return raw + + if isinstance(parsed, (int, float, bool)): + return parsed + return raw + + +def resolve_completion_params_variables( + completion_params: Mapping[str, Any], + variable_pool: VariablePool, +) -> dict[str, Any]: + """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. + + Security notes: + - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to + prevent denial-of-service through excessively large variable payloads. + - This follows the same ``VariablePool.convert_template`` pattern used across + Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream + model plugin receives these values as structured JSON key-value pairs — they + are never concatenated into raw HTTP headers or SQL queries. + - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are + restored to their native type rather than sent as a bare string. + """ + resolved: dict[str, Any] = {} + for key, value in completion_params.items(): + if isinstance(value, str) and VARIABLE_PATTERN.search(value): + segment_group = variable_pool.convert_template(value) + text = segment_group.text + if len(text) > MAX_RESOLVED_VALUE_LENGTH: + logger.warning( + "Resolved value for param '%s' truncated from %d to %d chars", + key, + len(text), + MAX_RESOLVED_VALUE_LENGTH, + ) + text = text[:MAX_RESOLVED_VALUE_LENGTH] + resolved[key] = _coerce_resolved_value(text) + else: + resolved[key] = value + return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5ed90ed7e3..a5492aee6b 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]): # fetch model config model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) model_name = model_instance.model_name model_provider = model_instance.provider model_stop = model_instance.stop diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..e6e8a44d06 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise InvalidModelTypeError("Model is not a Large Language Model") diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 59d0a2a4d8..928618fdbc 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variables = {"query": query} # fetch model instance model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) memory = self._memory # fetch instruction node_data.instruction = node_data.instruction or "" diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 76de5a0740..b7e7a6e60f 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -3,6 +3,7 @@ import logging import time import click +from sqlalchemy import select from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -24,13 +25,11 @@ def handle(sender, **kwargs): for document_id in document_ids: logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = ( - db.session.query(Document) - .where( + document = db.session.scalar( + select(Document).where( Document.id == document_id, Document.dataset_id == dataset_id, ) - .first() ) if not document: diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index b70c2183d2..4709534ae6 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,6 +1,6 @@ from typing import Any, cast -from sqlalchemy import select +from sqlalchemy import delete, select from events.app_event import app_model_config_was_updated from extensions.ext_database import db @@ -31,9 +31,9 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).where( - AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id - ).delete() + db.session.execute( + delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id) + ) if added_dataset_ids: for dataset_id in added_dataset_ids: diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 92bc9db075..20852b818e 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,6 +1,6 @@ from typing import cast -from sqlalchemy import select +from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from dify_graph.nodes import BuiltinNodeTypes @@ -31,9 +31,9 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).where( - AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id - ).delete() + db.session.execute( + delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id) + ) if added_dataset_ids: for dataset_id in added_dataset_ids: diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 74299956c0..02e50a90fc 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -3,6 +3,7 @@ import json import flask_login from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in +from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login): if admin_api_key and admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == workspace_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role == "owner") - .one_or_none() - ) + ).one_or_none() if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter_by(id=ta.account_id).first() + account = db.session.scalar(select(Account).where(Account.id == ta.account_id)) if account: account.current_tenant = tenant return account @@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login): end_user_id = decoded.get("end_user_id") if not end_user_id: raise Unauthorized("Invalid Authorization token.") - end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if not end_user: raise NotFound("End user not found.") return end_user @@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) end_user_id = decoded.get("end_user_id") if end_user_id: - end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if not end_user: raise NotFound("End user not found.") return end_user @@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login): server_code = request.view_args.get("server_code") if request.view_args else None if not server_code: raise Unauthorized("Invalid Authorization token.") - app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1)) if not app_mcp_server: raise NotFound("App MCP server not found.") - end_user = ( - db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1) ) if not end_user: raise NotFound("End user not found.") diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 83c5c2d12f..96f5915ff0 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage): kwargs = kwargs or _get_opendal_kwargs(scheme=scheme) if scheme == "fs": - root = kwargs.get("root", "storage") + root = kwargs.setdefault("root", "storage") Path(root).mkdir(parents=True, exist_ok=True) retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index ef55fe53c5..cb07ba58ae 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -424,13 +424,11 @@ def _build_from_datasource_file( datasource_file_id = mapping.get("datasource_file_id") if not datasource_file_id: raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = ( - db.session.query(UploadFile) - .where( + datasource_file = db.session.scalar( + select(UploadFile).where( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - .first() ) if datasource_file is None: diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index d4cb3e9971..8eeac37232 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -125,7 +125,8 @@ class BroadcastChannel(Protocol): a specific topic, all subscription should receive the published message. There are no restriction for the persistence of messages. Once a subscription is created, it - should receive all subsequent messages published. + should receive all subsequent messages published. However, a subscription should not receive + any message published before the subscription is established. `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. """ diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..aaeaf76f7b 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription): self._client = client self._key = key self._closed = threading.Event() - self._last_id = "0-0" + # Setting initial last id to `$` to signal redis that we only want new messages. + # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + self._last_id = "$" self._queue: queue.Queue[object] = queue.Queue() self._start_lock = threading.Lock() self._listener: threading.Thread | None = None diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index efce13f6f1..1afb42304d 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,16 +1,19 @@ +import logging import sys import urllib.parse from dataclasses import dataclass from typing import NotRequired import httpx -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict +logger = logging.getLogger(__name__) + JsonObject = dict[str, object] JsonObjectList = list[JsonObject] @@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False): class GitHubRawUserInfo(TypedDict): id: int | str login: str - name: NotRequired[str] - email: NotRequired[str] + name: NotRequired[str | None] + email: NotRequired[str | None] class GoogleRawUserInfo(TypedDict): @@ -127,9 +130,14 @@ class GitHubOAuth(OAuth): response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) - primary_email = next((email for email in email_info if email.get("primary") is True), None) + try: + email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) + email_response.raise_for_status() + email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + primary_email = next((email for email in email_info if email.get("primary") is True), None) + except (httpx.HTTPStatusError, ValidationError): + logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) + primary_email = None return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} @@ -137,8 +145,11 @@ class GitHubOAuth(OAuth): payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) email = payload.get("email") if not email: - email = f"{payload['id']}+{payload['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email) + raise ValueError( + 'Dify currently not supports the "Keep my email addresses private" feature,' + " please disable it and login again" + ) + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) class GoogleOAuth(OAuth): diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -43,7 +43,9 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -135,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -494,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -998,7 +1002,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/enums.py b/api/models/enums.py index 6af74cddc8..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" + @classmethod + def _missing_(cls, value): + if value == "end-user": + return cls.END_USER + else: + return super()._missing_(value) + class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" @@ -151,6 +158,13 @@ class FeedbackFromSource(StrEnum): ADMIN = "admin" +class FeedbackRating(StrEnum): + """MessageFeedback rating""" + + LIKE = "like" + DISLIKE = "dislike" + + class InvokeFrom(StrEnum): """How a conversation/message was invoked""" @@ -208,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" @@ -309,3 +330,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py index d0bd34efec..b2d09a7732 100644 --- a/api/models/execution_extra_content.py +++ b/api/models/execution_extra_content.py @@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent): form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) + def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent": + return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id) form: Mapped["HumanInputForm"] = relationship( "HumanInputForm", diff --git a/api/models/model.py b/api/models/model.py index fe70fcd401..68ff37bcaa 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -21,15 +21,32 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from dify_graph.file import helpers as file_helpers +from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus +from .enums import ( + ApiTokenType, + AppMCPServerStatus, + AppStatus, + BannerStatus, + ConversationFromSource, + ConversationStatus, + CreatorUserRole, + FeedbackFromSource, + FeedbackRating, + InvokeFrom, + MessageChainType, + MessageFileBelongsTo, + MessageStatus, + ProviderQuotaType, + TagType, +) from .provider_ids import GenericProviderID from .types import EnumText, LongText, StringUUID @@ -572,7 +589,9 @@ class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) @@ -921,12 +940,17 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) - status: Mapped[str] = mapped_column( - sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + status: Mapped[BannerStatus] = mapped_column( + EnumText(BannerStatus, length=255), + nullable=False, + server_default=sa.text("'enabled'::character varying"), + default=BannerStatus.ENABLED, ) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1007,10 +1031,12 @@ class Conversation(Base): # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(String(255), nullable=True) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) # ref: ConversationSource. - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(sa.DateTime) @@ -1153,7 +1179,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", - MessageFeedback.rating == "like", + MessageFeedback.rating == FeedbackRating.LIKE, ) ) or 0 @@ -1164,7 +1190,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", - MessageFeedback.rating == "dislike", + MessageFeedback.rating == FeedbackRating.DISLIKE, ) ) or 0 @@ -1179,7 +1205,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", - MessageFeedback.rating == "like", + MessageFeedback.rating == FeedbackRating.LIKE, ) ) or 0 @@ -1190,7 +1216,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", - MessageFeedback.rating == "dislike", + MessageFeedback.rating == FeedbackRating.DISLIKE, ) ) or 0 @@ -1359,8 +1385,10 @@ class Message(Base): ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) - invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) @@ -1713,8 +1741,8 @@ class MessageFeedback(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - rating: Mapped[str] = mapped_column(String(255), nullable=False) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False) + from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False) content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) @@ -1761,13 +1789,15 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) + belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column( + EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None + ) url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( @@ -1821,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) source: Mapped[str] = mapped_column(LongText, nullable=False) @@ -2071,7 +2103,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -2097,7 +2129,7 @@ class UploadFile(Base): # The `server_default` serves as a fallback mechanism. id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) size: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -2141,7 +2173,7 @@ class UploadFile(Base): self, *, tenant_id: str, - storage_type: str, + storage_type: StorageType, key: str, name: str, size: int, @@ -2206,7 +2238,7 @@ class MessageChain(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False) input: Mapped[str | None] = mapped_column(LongText, nullable=True) output: Mapped[str | None] = mapped_column(LongText, nullable=True) created_at: Mapped[datetime] = mapped_column( @@ -2381,7 +2413,7 @@ class Tag(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - type: Mapped[str] = mapped_column(String(16), nullable=False) + type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2466,7 +2498,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/provider.py b/api/models/provider.py index 7cefdbaba5..afeee20b1e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,6 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db +from .enums import CredentialSourceType, PaymentStatus from .types import EnumText, LongText, StringUUID @@ -209,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -237,7 +238,9 @@ class ProviderOrder(TypeBase): quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[str | None] = mapped_column(String(40)) total_amount: Mapped[int | None] = mapped_column(sa.Integer) - payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) + payment_status: Mapped[PaymentStatus] = mapped_column( + EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'") + ) paid_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -300,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + credential_source_type: Mapped[CredentialSourceType | None] = mapped_column( + EnumText(CredentialSourceType, length=40), nullable=True, default=None + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index c09f054e7d..63b27b9413 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolProviderType, + WorkflowToolParameterConfiguration, +) from .base import TypeBase from .engine import db from .model import Account, App, Tenant -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from core.entities.mcp_provider import MCPProviderEntity @@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase): # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # label name label_name: Mapped[str] = mapped_column(String(40), nullable=False) @@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase): # provider provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # tool name tool_name: Mapped[str] = mapped_column(String(128), nullable=False) # invoke parameters diff --git a/api/models/workflow.py b/api/models/workflow.py index 9bb249481f..334ec42058 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,3 +1,4 @@ +import copy import json import logging from collections.abc import Generator, Mapping, Sequence @@ -22,14 +23,14 @@ from sqlalchemy import ( from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated -from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey from dify_graph.file.constants import maybe_file_object from dify_graph.file.models import File from dify_graph.variables import utils as variable_utils @@ -302,26 +303,40 @@ class Workflow(Base): # bug def features(self) -> str: """ Convert old features structure to new features structure. + + This property avoids rewriting the underlying JSON when normalization + produces no effective change, to prevent marking the row dirty on read. """ if not self._features: return self._features - features = json.loads(self._features) - if features.get("file_upload", {}).get("image", {}).get("enabled", False): - image_enabled = True - image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) - image_transfer_methods = features["file_upload"]["image"].get( - "transfer_methods", ["remote_url", "local_file"] - ) - features["file_upload"]["enabled"] = image_enabled - features["file_upload"]["number_limits"] = image_number_limits - features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods - features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) - features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( - "allowed_file_extensions", [] - ) - del features["file_upload"]["image"] - self._features = json.dumps(features) + # Parse once and deep-copy before normalization to detect in-place changes. + original_dict = self._decode_features_payload(self._features) + if original_dict is None: + return self._features + + # Fast-path: if the legacy file_upload.image.enabled shape is absent, skip + # deep-copy and normalization entirely and return the stored JSON. + file_upload_payload = original_dict.get("file_upload") + if not isinstance(file_upload_payload, dict): + return self._features + file_upload = cast(dict[str, Any], file_upload_payload) + + image_payload = file_upload.get("image") + if not isinstance(image_payload, dict): + return self._features + image = cast(dict[str, Any], image_payload) + if "enabled" not in image: + return self._features + + normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict)) + + if normalized_dict == original_dict: + # No effective change; return stored JSON unchanged. + return self._features + + # Normalization changed the payload: persist the normalized JSON. + self._features = json.dumps(normalized_dict) return self._features @features.setter @@ -332,6 +347,44 @@ class Workflow(Base): # bug def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} + @property + def serialized_features(self) -> str: + """Return the stored features JSON without triggering compatibility rewrites.""" + return self._features + + @property + def normalized_features_dict(self) -> dict[str, Any]: + """Decode features with legacy normalization without mutating the model state.""" + if not self._features: + return {} + + features = self._decode_features_payload(self._features) + return self._normalize_features_payload(features) if features is not None else {} + + @staticmethod + def _decode_features_payload(features: str) -> dict[str, Any] | None: + """Decode workflow features JSON when it contains an object payload.""" + payload = json.loads(features) + return cast(dict[str, Any], payload) if isinstance(payload, dict) else None + + @staticmethod + def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]: + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = True + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) + features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( + "allowed_file_extensions", [] + ) + del features["file_upload"]["image"] + + return features + def walk_nodes( self, specific_node_type: NodeType | None = None ) -> Generator[tuple[str, Mapping[str, Any]], None, None]: @@ -517,6 +570,31 @@ class Workflow(Base): # bug ) self._environment_variables = environment_variables_json + @staticmethod + def normalize_environment_variable_mappings( + mappings: Sequence[Mapping[str, Any]], + ) -> list[dict[str, Any]]: + """Convert masked secret placeholders into the draft hidden sentinel. + + Regular draft sync requests should preserve existing secrets without shipping + plaintext values back from the client. The dedicated restore endpoint now + copies published secrets server-side, so draft sync only needs to normalize + the UI mask into `HIDDEN_VALUE`. + """ + masked_secret_value = encrypter.full_mask_token() + normalized_mappings: list[dict[str, Any]] = [] + + for mapping in mappings: + normalized_mapping = dict(mapping) + if ( + normalized_mapping.get("value_type") == SegmentType.SECRET.value + and normalized_mapping.get("value") == masked_secret_value + ): + normalized_mapping["value"] = HIDDEN_VALUE + normalized_mappings.append(normalized_mapping) + + return normalized_mappings + def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict: environment_variables = list(self.environment_variables) environment_variables = [ @@ -564,6 +642,12 @@ class Workflow(Base): # bug ensure_ascii=False, ) + def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None: + """Copy stored variable JSON directly for same-tenant restore flows.""" + self._environment_variables = source_workflow._environment_variables + self._conversation_variables = source_workflow._conversation_variables + self._rag_pipeline_variables = source_workflow._rag_pipeline_variables + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) @@ -936,8 +1020,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata: datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") - elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata: - trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {} + elif ( + self.node_type == TRIGGER_PLUGIN_NODE_TYPE + and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata + ): + trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {} provider_id = trigger_info.get("provider_id") if provider_id: extras["icon"] = TriggerManager.get_trigger_plugin_icon( @@ -1134,7 +1221,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1214,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 31b778ab8c..6ef98068e6 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.1" +version = "1.13.2" requires-python = ">=3.11,<3.13" dependencies = [ @@ -8,7 +8,7 @@ dependencies = [ "arize-phoenix-otel~=0.15.0", "azure-identity==1.25.3", "beautifulsoup4==4.14.3", - "boto3==1.42.68", + "boto3==1.42.73", "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.6.2", @@ -23,7 +23,7 @@ dependencies = [ "gevent~=25.9.1", "gmpy2~=2.3.0", "google-api-core>=2.19.1", - "google-api-python-client==2.192.0", + "google-api-python-client==2.193.0", "google-auth>=2.47.0", "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", @@ -40,7 +40,7 @@ dependencies = [ "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.10.37", - "litellm==1.82.2", # Pinned to avoid madoka dependency issue + "litellm==1.82.6", # Pinned to avoid madoka dependency issue "opentelemetry-api==1.28.0", "opentelemetry-distro==0.49b0", "opentelemetry-exporter-otlp==1.28.0", @@ -72,13 +72,14 @@ dependencies = [ "pyyaml~=6.0.1", "readabilipy~=0.3.0", "redis[hiredis]~=7.3.0", - "resend~=2.23.0", - "sentry-sdk[flask]~=2.54.0", + "resend~=2.26.0", + "sentry-sdk[flask]~=2.55.0", "sqlalchemy~=2.0.29", - "starlette==0.52.1", + "starlette==1.0.0", "tiktoken~=0.12.0", "transformers~=5.3.0", "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", + "pypandoc~=1.13", "yarl~=1.23.0", "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", @@ -91,7 +92,7 @@ dependencies = [ "apscheduler>=3.11.0", "weave>=0.52.16", "fastopenapi[flask]>=0.7.0", - "bleach~=6.2.0", + "bleach~=6.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -118,7 +119,7 @@ dev = [ "ruff~=0.15.5", "pytest~=9.0.2", "pytest-benchmark~=5.2.3", - "pytest-cov~=7.0.0", + "pytest-cov~=7.1.0", "pytest-env~=1.6.0", "pytest-mock~=3.15.1", "testcontainers~=4.14.1", @@ -173,7 +174,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.55.0", + "pyrefly>=0.57.1", ] ############################################################ @@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] # Required by vector store clients ############################################################ vdb = [ - "alibabacloud_gpdb20160503~=3.8.0", + "alibabacloud_gpdb20160503~=5.1.0", "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", "clickhouse-connect~=0.14.1", diff --git a/api/pytest.ini b/api/pytest.ini index 588dafe7eb..4d5d0ab6e0 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,6 +1,6 @@ [pytest] pythonpath = . -addopts = --cov=./api --cov-report=json --import-mode=importlib +addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index 13d2f24ca0..cf223f6e9e 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -3,6 +3,7 @@ import math import time import click +from sqlalchemy import select import app from core.helper.marketplace import fetch_global_plugin_manifest @@ -28,17 +29,15 @@ def check_upgradable_plugin_task(): now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green")) - strategies = ( - db.session.query(TenantPluginAutoUpgradeStrategy) - .where( + strategies = db.session.scalars( + select(TenantPluginAutoUpgradeStrategy).where( TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day, TenantPluginAutoUpgradeStrategy.upgrade_time_of_day < now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL, TenantPluginAutoUpgradeStrategy.strategy_setting != TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, ) - .all() - ) + ).all() total_strategies = len(strategies) click.echo(click.style(f"Total strategies: {total_strategies}", fg="green")) diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 2b74fb2dd0..04c954875f 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -2,7 +2,7 @@ import datetime import time import click -from sqlalchemy import text +from sqlalchemy import select, text from sqlalchemy.exc import SQLAlchemyError import app @@ -19,14 +19,12 @@ def clean_embedding_cache_task(): thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = ( - db.session.query(Embedding.id) + embedding_ids = db.session.scalars( + select(Embedding.id) .where(Embedding.created_at < thirty_days_ago) .order_by(Embedding.created_at.desc()) .limit(100) - .all() - ) - embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] + ).all() except SQLAlchemyError: raise if embedding_ids: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index d9fb6a24f1..0b0fc1b229 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -3,7 +3,7 @@ import time from typing import TypedDict import click -from sqlalchemy import func, select +from sqlalchemy import func, select, update from sqlalchemy.exc import SQLAlchemyError import app @@ -51,7 +51,7 @@ def clean_unused_datasets_task(): try: # Subquery for counting new documents document_subquery_new = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + select(Document.dataset_id, func.count(Document.id).label("document_count")) .where( Document.indexing_status == "completed", Document.enabled == True, @@ -64,7 +64,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + select(Document.dataset_id, func.count(Document.id).label("document_count")) .where( Document.indexing_status == "completed", Document.enabled == True, @@ -142,8 +142,8 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # Update document - db.session.query(Document).filter_by(dataset_id=dataset.id).update( - {Document.enabled: False} + db.session.execute( + update(Document).where(Document.dataset_id == dataset.id).values(enabled=False) ) db.session.commit() click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green")) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index ed46c1c70a..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -1,12 +1,14 @@ import time import click +from sqlalchemy import func, select import app from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -20,7 +22,7 @@ def create_tidb_serverless_task(): try: # check the number of idle tidb serverless idle_tidb_serverless_number = ( - db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count() + db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0 ) if idle_tidb_serverless_number >= tidb_serverless_number: break @@ -56,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index d738bf46fa..8479cdfb0c 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -49,16 +49,18 @@ def mail_clean_document_notify_task(): if plan != CloudPlan.SANDBOX: knowledge_details = [] # check tenant - tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first() + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) if not tenant: continue # check current owner - current_owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) ) if not current_owner_join: continue - account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first() + account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) if not account: continue @@ -71,7 +73,7 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..69c7c0c95a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -241,7 +241,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +257,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 5ab47c799a..70d4ce1ee6 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -335,7 +335,11 @@ class BillingService: # Redis returns bytes, decode to string and parse JSON json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value plan_dict = json.loads(json_str) + # NOTE (hj24): New billing versions may return timestamp as str, and validate_python + # in non-strict mode will coerce it to the expected int type. + # To preserve compatibility, always keep non-strict mode here and avoid strict mode. subscription_plan = subscription_adapter.validate_python(plan_dict) + # NOTE END tenant_plans[tenant_id] = subscription_plan except Exception: logger.exception( diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..969ca68545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.file import helpers as file_helpers from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -227,7 +228,7 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid @@ -253,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -348,7 +352,7 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -716,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -952,8 +956,8 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error @@ -975,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -990,9 +994,9 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: @@ -1017,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1028,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1088,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1439,7 +1443,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -1906,8 +1910,8 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model @@ -2039,7 +2043,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2639,7 +2643,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -2688,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2711,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3100,7 +3104,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3124,7 +3128,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3157,7 +3161,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3207,7 +3211,7 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3229,9 +3233,9 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3254,7 +3258,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3321,7 +3325,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3344,7 +3348,7 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3381,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3408,7 +3412,7 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3418,7 +3422,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3435,7 +3439,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3448,7 +3452,7 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3480,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( @@ -3786,7 +3790,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3849,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index 1a1cbbb450..e7473d371b 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -7,6 +7,7 @@ from flask import Response from sqlalchemy import or_ from extensions.ext_database import db +from models.enums import FeedbackRating from models.model import Account, App, Conversation, Message, MessageFeedback @@ -100,7 +101,7 @@ class FeedbackService: "ai_response": message.answer[:500] + "..." if len(message.answer) > 500 else message.answer, # Truncate long responses - "feedback_rating": "👍" if feedback.rating == "like" else "👎", + "feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎", "feedback_rating_raw": feedback.rating, "feedback_comment": feedback.content or "", "feedback_source": feedback.from_source, diff --git a/api/services/file_service.py b/api/services/file_service.py index ecb30faaa8..a7060f3b92 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account @@ -93,7 +94,7 @@ class FileService: # save file to db upload_file = UploadFile( tenant_id=current_tenant_id or "", - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=filename, size=file_size, @@ -152,7 +153,7 @@ class FileService: # save file to db upload_file = UploadFile( tenant_id=tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=text_name, size=len(text), diff --git a/api/services/message_service.py b/api/services/message_service.py index 789b6c2f8c..fc87802f51 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( @@ -172,7 +173,7 @@ class MessageService: app_model: App, message_id: str, user: Union[Account, EndUser] | None, - rating: str | None, + rating: FeedbackRating | None, content: str | None, ): if not user: @@ -197,7 +198,7 @@ class MessageService: message_id=message.id, rating=rating, content=content, - from_source=("user" if isinstance(user, EndUser) else "admin"), + from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 2133dc5b3a..bf3b6db3ed 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -103,9 +104,9 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True if config_from == "predefined-model": - credential_source_type = "provider" + credential_source_type = CredentialSourceType.PROVIDER else: - credential_source_type = "custom_model" + credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations load_balancing_configs = ( @@ -421,7 +422,11 @@ class ModelLoadBalancingService: raise ValueError("Invalid load balancing config name") if credential_id: - credential_source = "provider" if config_from == "predefined-model" else "custom_model" + credential_source = ( + CredentialSourceType.PROVIDER + if config_from == "predefined-model" + else CredentialSourceType.CUSTOM_MODEL + ) assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 571ca6c7a6..f996db11dc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict | None: + result: dict | None try: result = self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: @@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict: """ Fetch pipeline template detail from dify official. - :param template_id: Pipeline ID - :return: + + :param template_id: Pipeline template ID + :return: Template detail dict + :raises ValueError: When upstream returns a non-200 status code """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates/{template_id}" response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: - return None + raise ValueError( + "fetch pipeline template detail failed," + + f" status_code: {response.status_code}," + + f" response: {response.text[:1000]}" + ) data: dict = response.json() return data diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index ecee562c93..296b9f0890 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -79,10 +79,11 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader +from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) @@ -117,13 +118,21 @@ class RagPipelineService: def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. + :param template_id: template id - :return: + :param type: template type, "built-in" or "customized" + :return: template detail dict, or None if not found """ if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + if built_in_result is None: + logger.warning( + "pipeline template retrieval returned empty result, template_id: %s, mode: %s", + template_id, + mode, + ) return built_in_result else: mode = "customized" @@ -226,6 +235,21 @@ class RagPipelineService: return workflow + def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: + """Fetch a published workflow snapshot by ID for restore operations.""" + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == workflow_id, + ) + .first() + ) + if workflow and workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("source workflow must be published") + return workflow + def get_all_published_workflow( self, *, @@ -319,6 +343,42 @@ class RagPipelineService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + pipeline: Pipeline, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published pipeline workflow snapshot into the draft workflow. + + Pipelines reuse the shared draft-restore field copy helper, but still own + the pipeline-specific flush/link step that wires a newly created draft + back onto ``pipeline.workflow_id``. + """ + source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + ) + + if is_new_draft: + db.session.add(draft_workflow) + db.session.flush() + pipeline.workflow_id = draft_workflow.id + + db.session.commit() + + return draft_workflow + def publish_workflow( self, *, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..fd66d55c1a 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd..215a8c8528 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,38 +102,38 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml @@ -169,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 13a6363bc3..ed7a33feae 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,6 +12,8 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.model_entities import ModelType @@ -30,7 +32,7 @@ class SummaryIndexService: def generate_summary_for_segment( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> tuple[str, LLMUsage]: """ Generate summary for a single segment. @@ -139,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -600,7 +602,7 @@ class SummaryIndexService: def generate_and_vectorize_summary( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> DocumentSegmentSummary: """ Generate summary for a segment and vectorize it. @@ -705,7 +707,7 @@ class SummaryIndexService: def generate_summaries_for_document( dataset: Dataset, document: DatasetDocument, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, segment_ids: list[str] | None = None, only_parent_chunks: bool = False, ) -> list[DocumentSegmentSummary]: @@ -723,7 +725,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -850,7 +852,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -888,7 +890,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -980,7 +982,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1011,7 +1013,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..70bf7f16f2 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding @@ -83,7 +84,7 @@ class TagService: raise ValueError("Tag name already exists") tag = Tag( name=args["name"], - type=args["type"], + type=TagType(args["type"]), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index dc883f0daa..408b1c22d1 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,10 +1,10 @@ import json import logging -from collections.abc import Mapping from typing import Any, cast from httpx import get from sqlalchemy import select +from typing_extensions import TypedDict from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) +class ApiSchemaParseResult(TypedDict): + schema_type: str + parameters_schema: list[dict[str, Any]] + credentials_schema: list[dict[str, Any]] + warning: dict[str, str] + + class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> Mapping[str, Any]: + def parser_api_schema(schema: str) -> ApiSchemaParseResult: """ parse api schema to tool bundle """ @@ -71,7 +78,7 @@ class ApiToolManageService: ] return cast( - Mapping, + ApiSchemaParseResult, jsonable_encoder( { "schema_type": schema_type, diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 0be106f597..deb26438a8 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError +from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.utils.encryption import ProviderConfigEncrypter from models.tools import MCPToolProvider @@ -681,7 +682,7 @@ class MCPToolManageService: raise ValueError(f"Failed to re-connect MCP server: {e}") from e def _build_tool_provider_response( - self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list + self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool] ) -> ToolProviderApiEntity: """Build API response for tool provider.""" user = db_provider.load_user() @@ -703,7 +704,7 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_url} already exists") if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - raise + raise error def _is_valid_url(self, url: str) -> bool: """Validate URL format.""" diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..bb94a03ba3 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document @@ -45,7 +45,7 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/services/website_service.py b/api/services/website_service.py index 15ec4657d9..b2917ba152 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -9,7 +9,7 @@ import httpx from flask_login import current_user from core.helper import encrypter -from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp +from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData from core.rag.extractor.watercrawl.provider import WaterCrawlProvider from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -216,8 +216,10 @@ class WebsiteService: "max_depth": request.options.max_depth, "use_sitemap": request.options.use_sitemap, } - return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( - url=request.url, options=options + return dict( + WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( + url=request.url, options=options + ) ) @classmethod @@ -270,13 +272,13 @@ class WebsiteService: @classmethod def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) - result = firecrawl_app.check_crawl_status(job_id) - crawl_status_data = { - "status": result.get("status", "active"), + result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id) + crawl_status_data: dict[str, Any] = { + "status": result["status"], "job_id": job_id, - "total": result.get("total", 0), - "current": result.get("current", 0), - "data": result.get("data", []), + "total": result["total"] or 0, + "current": result["current"] or 0, + "data": result["data"], } if crawl_status_data["status"] == "completed": website_crawl_time_cache_key = f"website_crawl_{job_id}" @@ -289,8 +291,8 @@ class WebsiteService: return crawl_status_data @classmethod - def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: - return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id) + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]: + return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)) @classmethod def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: @@ -343,7 +345,7 @@ class WebsiteService: @classmethod def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: - crawl_data: list[dict[str, Any]] | None = None + crawl_data: list[FirecrawlDocumentData] | None = None file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): stored_data = storage.load_once(file_key) @@ -352,19 +354,22 @@ class WebsiteService: else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) result = firecrawl_app.check_crawl_status(job_id) - if result.get("status") != "completed": + if result["status"] != "completed": raise ValueError("Crawl job is not completed") - crawl_data = result.get("data") + crawl_data = result["data"] if crawl_data: for item in crawl_data: - if item.get("source_url") == url: + if item["source_url"] == url: return dict(item) return None @classmethod - def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: - return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + def _get_watercrawl_url_data( + cls, job_id: str, url: str, api_key: str, config: dict[str, Any] + ) -> dict[str, Any] | None: + result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + return dict(result) if result is not None else None @classmethod def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: @@ -416,8 +421,8 @@ class WebsiteService: def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params = {"onlyMainContent": request.only_main_content} - return firecrawl_app.scrape_url(url=request.url, params=params) + return dict(firecrawl_app.scrape_url(url=request.url, params=params)) @classmethod - def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: - return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url) + def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: + return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 006483fe97..f0596e44c8 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,7 @@ import json -from typing import Any, TypedDict +from typing import Any + +from typing_extensions import TypedDict from core.app.app_config.entities import ( DatasetEntity, @@ -34,6 +36,17 @@ class _NodeType(TypedDict): data: dict[str, Any] +class _EdgeType(TypedDict): + id: str + source: str + target: str + + +class WorkflowGraph(TypedDict): + nodes: list[_NodeType] + edges: list[_EdgeType] + + class WorkflowConverter: """ App Convert to Workflow Mode @@ -107,7 +120,7 @@ class WorkflowConverter: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph: dict[str, Any] = {"nodes": [], "edges": []} + graph: WorkflowGraph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -385,7 +398,7 @@ class WorkflowConverter: self, original_app_mode: AppMode, new_app_mode: AppMode, - graph: dict, + graph: WorkflowGraph, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, file_upload: FileUploadConfig | None = None, @@ -595,7 +608,7 @@ class WorkflowConverter: "data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"}, } - def _create_edge(self, source: str, target: str): + def _create_edge(self, source: str, target: str) -> _EdgeType: """ Create Edge :param source: source node id @@ -604,7 +617,7 @@ class WorkflowConverter: """ return {"id": f"{source}-{target}", "source": source, "target": target} - def _append_node(self, graph: dict[str, Any], node: _NodeType): + def _append_node(self, graph: WorkflowGraph, node: _NodeType): """ Append Node to Graph diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 7147fe1eab..9489618762 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -5,6 +5,7 @@ from typing import Any from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from typing_extensions import TypedDict from dify_graph.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun @@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService from services.workflow.entities import TriggerMetadata +class LogViewDetails(TypedDict): + trigger_metadata: dict[str, Any] | None + + # Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it class LogView: """Lightweight wrapper for WorkflowAppLog with computed details. @@ -22,12 +27,12 @@ class LogView: - Proxies all other attributes to the underlying `WorkflowAppLog` """ - def __init__(self, log: WorkflowAppLog, details: dict | None): + def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None): self.log = log self.details_ = details @property - def details(self) -> dict | None: + def details(self) -> LogViewDetails | None: return self.details_ def __getattr__(self, name): diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index fb1a3f30c0..f124e137c3 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -35,7 +35,7 @@ from factories.variable_factory import build_segment, segment_to_variable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation -from models.enums import DraftVariableType +from models.enums import ConversationFromSource, DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -601,7 +601,7 @@ class WorkflowDraftVariableService: system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.DEBUGGER, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account_id, ) diff --git a/api/services/workflow_restore.py b/api/services/workflow_restore.py new file mode 100644 index 0000000000..083235d228 --- /dev/null +++ b/api/services/workflow_restore.py @@ -0,0 +1,58 @@ +"""Shared helpers for restoring published workflow snapshots into drafts. + +Both app workflows and RAG pipeline workflows restore the same workflow fields +from a published snapshot into a draft. Keeping that field-copy logic in one +place prevents the two restore paths from drifting when we add or adjust draft +state in the future. Restore stays within a tenant, so we can safely reuse the +serialized workflow storage blobs without decrypting and re-encrypting secrets. +""" + +from collections.abc import Callable +from datetime import datetime + +from models import Account +from models.workflow import Workflow, WorkflowType + +UpdatedAtFactory = Callable[[], datetime] + + +def apply_published_workflow_snapshot_to_draft( + *, + tenant_id: str, + app_id: str, + source_workflow: Workflow, + draft_workflow: Workflow | None, + account: Account, + updated_at_factory: UpdatedAtFactory, +) -> tuple[Workflow, bool]: + """Copy a published workflow snapshot into a draft workflow record. + + The caller remains responsible for source lookup, validation, flushing, and + post-commit side effects. This helper only centralizes the shared draft + creation/update semantics used by both restore entry points. Features are + copied from the stored JSON payload so restore does not normalize and dirty + the published source row before the caller commits. + """ + if not draft_workflow: + workflow_type = ( + source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type + ) + draft_workflow = Workflow( + tenant_id=tenant_id, + app_id=app_id, + type=workflow_type, + version=Workflow.VERSION_DRAFT, + graph=source_workflow.graph, + features=source_workflow.serialized_features, + created_by=account.id, + ) + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + return draft_workflow, True + + draft_workflow.graph = source_workflow.graph + draft_workflow.features = source_workflow.serialized_features + draft_workflow.updated_by = account.id + draft_workflow.updated_at = updated_at_factory() + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + + return draft_workflow, False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e13cdd5f27..66976058c0 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -63,7 +63,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService from services.enterprise.plugin_manager_service import PluginCredentialType -from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError +from services.errors.app import ( + IsDraftWorkflowError, + TriggerNodeLimitExceededError, + WorkflowHashNotEqualError, + WorkflowNotFoundError, +) from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -75,6 +80,7 @@ from .human_input_delivery_test_service import ( HumanInputDeliveryTestService, ) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService +from .workflow_restore import apply_published_workflow_snapshot_to_draft class WorkflowService: @@ -279,6 +285,43 @@ class WorkflowService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + app_model: App, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published workflow snapshot into the draft workflow. + + Secret environment variables are copied server-side from the selected + published workflow so the normal draft sync flow stays stateless. + """ + source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_graph_structure(graph=source_workflow.graph_dict) + + draft_workflow = self.get_draft_workflow(app_model=app_model) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=naive_utc_now, + ) + + if is_new_draft: + db.session.add(draft_workflow) + + db.session.commit() + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + + return draft_workflow + def publish_workflow( self, *, diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..dd58378e0e 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,6 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -119,7 +120,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -52,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 75ae1f6316..f8c7964805 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index afb6938baa..d10e5ed13c 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -13,6 +13,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -154,7 +155,7 @@ class TestChatMessageApiPermissions: re_sign_file_url_answer="", answer_tokens=0, provider_response_latency=0.0, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=mock_account.id, feedbacks=[], diff --git a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py index 0f8b42e98b..309a0b015a 100644 --- a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py +++ b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py @@ -14,6 +14,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, MessageFeedback from services.feedback_service import FeedbackService @@ -77,8 +78,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content=None, from_end_user_id=str(uuid.uuid4()), from_account_id=None, @@ -90,8 +91,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="The response was not helpful", from_end_user_id=None, from_account_id=str(uuid.uuid4()), @@ -277,8 +278,8 @@ class TestFeedbackExportApi: # Verify service was called with correct parameters mock_export_feedbacks.assert_called_once_with( app_id=mock_app_model.id, - from_source="user", - rating="dislike", + from_source=FeedbackFromSource.USER, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index b4e3a0e4de..db4bbc1ca1 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Session from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db +from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=storage_key, name="test_file.txt", size=1024, @@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( tenant_id=other_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key="other_tenant_key", name="other_file.txt", size=1024, diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2..a942690cbd 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index b6aeb54cca..9d3a869691 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType from dify_graph.variables.variables import StringVariable from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment from libs import datetime_utils from models.enums import CreatorUserRole @@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Create an upload file record upload_file = UploadFile( tenant_id=self._test_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_offload_{uuid.uuid4()}.json", name="test_offload.json", size=len(content_bytes), @@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Create upload file record upload_file = UploadFile( tenant_id=self._test_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_integration_{uuid.uuid4()}.txt", name="test_integration.txt", size=len(content_bytes), diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 988313e68d..bc83c6cc12 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ from sqlalchemy import delete from core.db.session_factory import session_factory from dify_graph.variables.segments import StringSegment +from extensions.storage.storage_type import StorageType from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: with session_factory.create_session() as session: upload_file1 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file1.json", name="file1.json", size=1024, @@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: ) upload_file2 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file2.json", name="file2.json", size=2048, @@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit: with session_factory.create_session() as session: upload_file1 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file1.json", name="file1.json", size=1024, @@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit: ) upload_file2 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file2.json", name="file2.json", size=2048, diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada31..818ae46625 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -68,7 +68,7 @@ def init_tool_node(config: dict): return node -def test_tool_variable_invoke(): +def test_tool_variable_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", @@ -103,7 +103,7 @@ def test_tool_variable_invoke(): assert item.node_run_result.outputs.get("text") is not None -def test_tool_mixed_invoke(): +def test_tool_mixed_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 0bdd3bdc47..ef0ca4232d 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -165,8 +165,9 @@ class DifyTestContainers: # Start Dify Sandbox container for code execution environment # Dify Sandbox provides a secure environment for executing user code + # Use pinned version 0.2.12 to match production docker-compose configuration logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network) + self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 6f2e008d44..4f606dccb8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -13,7 +13,7 @@ from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin from models.account import AccountStatus, TenantAccountRole -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun from services.account_service import AccountService @@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C inputs={}, status="normal", mode=AppMode.CHAT, - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, ) db_session.add(conversation) @@ -124,7 +124,7 @@ def _create_message( answer_price_unit=0.001, currency="USD", status="normal", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, workflow_run_id=workflow_run_id, inputs={"query": "Hello"}, diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..6b51ec98bc --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -0,0 +1,342 @@ +"""Authenticated controller integration tests for console message APIs.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload +from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation: + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Test Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + account_id: str, + *, + created_at_offset_seconds: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=1, + message_unit_price=Decimal("0.0001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=1, + answer_unit_price=Decimal("0.0001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal("0.0002"), + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +class TestMessageValidators: + def test_chat_messages_query_validators(self) -> None: + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self) -> None: + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self) -> None: + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +def test_chat_message_list_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": str(uuid4())}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_chat_message_list_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0) + second = _create_message( + db_session_with_containers, + app.id, + conversation.id, + account.id, + created_at_offset_seconds=1, + ) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": conversation.id, "limit": 1}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["limit"] == 1 + assert payload["has_more"] is True + assert len(payload["data"]) == 1 + assert payload["data"][0]["id"] == second.id + + +def test_message_feedback_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": str(uuid4()), "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_message_feedback_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": message.id, "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + feedback = db_session_with_containers.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == message.id) + ) + assert feedback is not None + assert feedback.rating == FeedbackRating.LIKE + assert feedback.from_account_id == account.id + + +def test_message_annotation_count( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + db_session_with_containers.add( + MessageAnnotation( + app_id=app.id, + conversation_id=conversation.id, + message_id=message.id, + question="Q", + content="A", + account_id=account.id, + ) + ) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/annotations/count", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"count": 1} + + +def test_message_suggested_questions_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"data": ["q1", "q2"]} + + +@pytest.mark.parametrize( + ("exc", "expected_status", "expected_code"), + [ + (MessageNotExistsError(), 404, "not_found"), + (ConversationNotExistsError(), 404, "not_found"), + (ProviderTokenNotInitError(), 400, "provider_not_initialize"), + (QuotaExceededError(), 400, "provider_quota_exceeded"), + (ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"), + (SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"), + (Exception(), 500, "internal_server_error"), + ], +) +def test_message_suggested_questions_errors( + exc: Exception, + expected_status: int, + expected_code: str, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + side_effect=exc, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == expected_status + payload = response.get_json() + assert payload is not None + assert payload["code"] == expected_code + + +def test_message_feedback_export_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/feedbacks/export", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"exported": True} + + +def test_message_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/messages/{message.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == message.id diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..963cfe53e5 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02") + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..f037ad77c0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,415 @@ +"""Authenticated controller integration tests for workflow draft variable APIs.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from dify_graph.variables.segments import StringSegment +from factories.variable_factory import segment_to_variable +from models import Workflow +from models.model import AppMode +from models.workflow import WorkflowDraftVariable +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_draft_workflow( + db_session: Session, + app_id: str, + tenant_id: str, + account_id: str, + *, + environment_variables: list | None = None, + conversation_variables: list | None = None, +) -> Workflow: + workflow = Workflow.new( + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=account_id, + environment_variables=environment_variables or [], + conversation_variables=conversation_variables or [], + rag_pipeline_variables=[], + ) + db_session.add(workflow) + db_session.commit() + return workflow + + +def _create_node_variable( + db_session: Session, + app_id: str, + user_id: str, + *, + node_id: str = "node_1", + name: str = "test_var", +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + user_id=user_id, + node_id=node_id, + name=name, + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + visible=True, + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _create_system_variable( + db_session: Session, app_id: str, user_id: str, name: str = "query" +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + user_id=user_id, + name=name, + value=StringSegment(value="system-value"), + node_execution_id=str(uuid.uuid4()), + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _build_environment_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, name], + name=name, + description=f"Environment variable {name}", + ) + + +def _build_conversation_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[CONVERSATION_VARIABLE_NODE_ID, name], + name=name, + description=f"Conversation variable {name}", + ) + + +def test_workflow_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"items": [], "total": 0} + + +def test_workflow_variable_collection_get_not_exist( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "draft_workflow_not_exist" + + +def test_workflow_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_node_variable(db_session_with_containers, app.id, account.id) + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var") + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + remaining = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + ) + ).all() + assert remaining == [] + + +def test_node_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [node_variable.id] + + +def test_node_variable_collection_get_invalid_node_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "invalid_param" + + +def test_node_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456") + target_id = target.id + untouched_id = untouched.id + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id)) + is None + ) + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id)) + is not None + ) + + +def test_variable_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "test_var" + + +def test_variable_api_get_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_variable_api_patch_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.patch( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + json={"name": "renamed_var"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "renamed_var" + + refreshed = db_session_with_containers.scalar( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id) + ) + assert refreshed is not None + assert refreshed.name == "renamed_var" + + +def test_variable_api_delete_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_variable_reset_api_put_success_returns_no_content_without_execution( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.put( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_conversation_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + conversation_variables=[_build_conversation_variable("session_name", "Alice")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/conversation-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["name"] for item in payload["items"]] == ["session_name"] + + created = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID, + ) + ).all() + assert len(created) == 1 + + +def test_system_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + variable = _create_system_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/system-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [variable.id] + + +def test_environment_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + environment_variables=[_build_environment_variable("api_key", "secret-value")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/environment-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["items"][0]["name"] == "api_key" + assert payload["items"][0]["value"] == "secret-value" diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..00309c25d6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,131 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) + is None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..81b5423261 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,120 @@ +"""Controller integration tests for console OAuth data source routes.""" + +from unittest.mock import MagicMock, patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.source import DataSourceOauthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_oauth_url_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + provider = MagicMock() + provider.get_authorization_url.return_value = "http://oauth.provider/auth" + + with ( + patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}), + patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None), + ): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/notion", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert tenant.id == account.current_tenant_id + assert response.status_code == 200 + assert response.get_json() == {"data": "http://oauth.provider/auth"} + provider.get_authorization_url.assert_called_once() + + +def test_get_oauth_url_invalid_provider( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/unknown_provider", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + +def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion") + + assert response.status_code == 302 + assert "error=Access%20denied" in response.location + + +def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None: + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123") + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.get_access_token.assert_called_once_with("auth_code_123") + + +def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid code"} + + +def test_sync_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceOauthBinding( + tenant_id=tenant.id, + access_token="test-access-token", + provider="notion", + source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}, + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get( + f"/console/api/oauth/data-source/notion/{binding.id}/sync", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.sync_data_source.assert_called_once_with(binding.id) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_email_register.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 724c80f18c..879c337319 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for email register controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestEmailRegisterSendEmailApi: - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - mock_session_cls, app, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False mock_account = MagicMock() - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), @@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi: feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -114,7 +107,6 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @@ -125,7 +117,6 @@ class TestEmailRegisterResetApi: mock_get_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_create_account, mock_login, mock_reset_login_rate, @@ -136,14 +127,10 @@ class TestEmailRegisterResetApi: token_pair = MagicMock() token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -159,19 +146,19 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_forgot_password.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 8403777dc9..7b7393dade 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for forgot password controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestForgotPasswordSendEmailApi: - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - mock_session_cls, app, ): mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) controller_features = SimpleNamespace(is_allow_register=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( "controllers.console.auth.forgot_password.FeatureService.get_system_features", return_value=controller_features, @@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @@ -126,7 +120,6 @@ class TestForgotPasswordResetApi: mock_get_reset_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_update_account, app, ): @@ -134,12 +127,8 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - wraps_features = SimpleNamespace(enable_email_password_login=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): @@ -157,20 +146,22 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" + from unittest.mock import MagicMock + mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py similarity index 92% rename from api/tests/unit_tests/controllers/console/auth/test_oauth.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 6345c2ab23..a2f1328579 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for OAuth controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.mark.parametrize( ("github_config", "google_config", "expected_github", "expected_google"), @@ -64,10 +65,8 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_oauth_provider(self): @@ -131,10 +130,8 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def oauth_setup(self): @@ -190,15 +187,8 @@ class TestOAuthCallback: (KeyError("Missing key"), "OAuth process failed"), ], ) - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions( - self, mock_get_providers, mock_db, resource, app, exception, expected_error - ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - + def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): # Import the real requests module to create a proper exception import httpx @@ -258,7 +248,6 @@ class TestOAuthCallback: ) @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") @@ -269,7 +258,6 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - mock_db, mock_tenant_service, mock_account_service, resource, @@ -278,10 +266,6 @@ class TestOAuthCallback: account_status, expected_redirect, ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - mock_db.session.commit = MagicMock() mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -306,14 +290,12 @@ class TestOAuthCallback: @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") def test_should_activate_pending_account( self, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -338,12 +320,10 @@ class TestOAuthCallback: assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None - mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.redirect") @@ -352,7 +332,6 @@ class TestOAuthCallback: mock_redirect, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -414,6 +393,10 @@ class TestOAuthCallback: class TestAccountGeneration: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def user_info(self): return OAuthUserInfo(id="123", name="Test User", email="test@example.com") @@ -425,15 +408,10 @@ class TestAccountGeneration: return account @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Account") - @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account + self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account ): - # Mock db.engine for Session creation - mock_db.engine = MagicMock() - # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) @@ -443,15 +421,14 @@ class TestAccountGeneration: # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None - mock_session_instance = MagicMock() - mock_session.return_value.__enter__.return_value = mock_session_instance mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account - mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + mock_get_account.assert_called_once() - def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -462,7 +439,7 @@ class TestAccountGeneration: result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert result == expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( @@ -478,10 +455,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_handle_account_generation_scenarios( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, @@ -519,10 +494,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_register_with_lowercase_email( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..2ef27133d8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,365 @@ +"""Controller integration tests for console OAuth server routes.""" + +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _build_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert "redirect_uri is invalid" in payload["message"] + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert "client_id is invalid" in payload["message"] + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account.avatar = "avatar_url" + db_session_with_containers.commit() + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": account.email, + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py new file mode 100644 index 0000000000..9e2084f393 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -0,0 +1,85 @@ +"""Shared helpers for authenticated console controller integration tests.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.model import App, AppMode +from services.account_service import AccountService + + +def ensure_dify_setup(db_session: Session) -> None: + """Create a setup marker once so setup-protected console routes can be exercised.""" + if db_session.scalar(select(DifySetup).limit(1)) is not None: + return + + db_session.add(DifySetup(version=dify_config.project.version)) + db_session.commit() + + +def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + """Create an initialized owner account with a current tenant.""" + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + db_session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + ) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + ensure_dify_setup(db_session) + return account, tenant + + +def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App: + """Create a minimal app row that can be loaded by get_app_model.""" + app = App( + tenant_id=tenant_id, + name="Test App", + mode=mode, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]: + """Attach console auth cookies/headers for endpoints guarded by login_required.""" + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost") + return { + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + } diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index cb7cd37a3f..8e70fc0bb0 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Session from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db +from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=storage_key, name="test_file.txt", size=1024, @@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( tenant_id=other_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key="other_tenant_key", name="other_file.txt", size=1024, diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 573f84cb0b..fb8d1808f9 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -7,6 +7,7 @@ from uuid import uuid4 from dify_graph.nodes.human_input.entities import FormDefinition, UserAction from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent from models.human_input import HumanInputForm, HumanInputFormStatus from models.model import App, Conversation, Message @@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: introduction="", system_instruction="", status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, from_end_user_id=None, ) @@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: answer_unit_price=Decimal("0.001"), provider_response_latency=0.5, currency="USD", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, workflow_run_id=workflow_run_id, ) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 0000000000..a79208f649 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. +""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d1..0000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 76e586e65f..49b370990a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,6 +2,7 @@ from __future__ import annotations +import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import PauseReasonType +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.human_input import ( + BackstageRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, + _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: WorkflowRun.app_id == scope.app_id, ) ) + + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(HumanInputForm).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + ) session.commit() for state_key in scope.state_keys: @@ -193,7 +218,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -253,7 +278,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -504,3 +529,200 @@ class TestDeleteWorkflowPause: with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"): repository.delete_workflow_pause(pause_entity=pause_entity) + + +class TestPrivateWorkflowPauseEntity: + """Integration tests for _PrivateWorkflowPauseEntity using real DB models.""" + + def test_properties( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Entity properties delegate to the persisted WorkflowPause model.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(pause.state_object_key) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + + assert entity.id == pause.id + assert entity.workflow_execution_id == workflow_run.id + assert entity.resumed_at is None + + def test_get_state( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state loads state data from storage using the persisted state_object_key.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result = entity.get_state() + + assert result == expected_state + + def test_get_state_caching( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state caches the result so storage is only accessed once.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result1 = entity.get_state() + # Delete from storage to prove second call uses cache + storage.delete(state_key) + test_scope.state_keys.discard(state_key) + result2 = entity.get_state() + + assert result1 == expected_state + assert result2 == expected_state + + +class TestBuildHumanInputRequiredReason: + """Integration tests for _build_human_input_required_reason using real DB models.""" + + def test_prefers_backstage_token_when_available( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Use backstage token when multiple recipient types may exist.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + access_token = secrets.token_urlsafe(8) + recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=access_token, + ) + db_session_with_containers.add(recipient) + db_session_with_containers.flush() + + # Create a pause so the reason has a valid pause_id + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + # Refresh to ensure we have DB-round-tripped objects + db_session_with_containers.refresh(form_model) + db_session_with_containers.refresh(reason_model) + db_session_with_containers.refresh(recipient) + + reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + + assert isinstance(reason, HumanInputRequired) + assert reason.form_token == access_token + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..ed998c9ed0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,407 @@ +"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers. + +Part of #32454 — replaces the mock-based unit tests with real database interactions. +""" + +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from dify_graph.nodes.human_input.entities import FormDefinition, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, InvokeFrom +from models.execution_extra_content import ExecutionExtraContent, HumanInputContent +from models.human_input import ( + ConsoleRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from models.model import App, Conversation, Message +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows. + + IDs are populated after flushing the base entities to the database. + """ + + tenant_id: str = "" + app_id: str = "" + user_id: str = "" + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(ExecutionExtraContent).where( + ExecutionExtraContent.workflow_run_id.in_( + select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id) + ) + ) + ) + session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id)) + session.execute(delete(Message).where(Message.app_id == scope.app_id)) + session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id)) + session.execute(delete(App).where(App.id == scope.app_id)) + session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id)) + session.execute(delete(Account).where(Account.id == scope.user_id)) + session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id)) + session.commit() + + +def _seed_base_entities(session: Session, scope: _TestScope) -> None: + """Create the base tenant, account, and app needed by tests.""" + tenant = Tenant(name="Test Tenant") + session.add(tenant) + session.flush() + scope.tenant_id = tenant.id + + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + scope.user_id = account.id + + tenant_join = TenantAccountJoin( + tenant_id=scope.tenant_id, + account_id=scope.user_id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(tenant_join) + + app = App( + tenant_id=scope.tenant_id, + name="Test App", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=scope.user_id, + updated_by=scope.user_id, + ) + session.add(app) + session.flush() + scope.app_id = app.id + + +def _create_conversation(session: Session, scope: _TestScope) -> Conversation: + conversation = Conversation( + app_id=scope.app_id, + mode="chat", + name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + from_end_user_id=None, + ) + conversation.inputs = {} + session.add(conversation) + session.flush() + return conversation + + +def _create_message( + session: Session, + scope: _TestScope, + conversation_id: str, + workflow_run_id: str, +) -> Message: + message = Message( + app_id=scope.app_id, + conversation_id=conversation_id, + inputs={}, + query="test query", + message={"messages": []}, + answer="test answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + workflow_run_id=workflow_run_id, + ) + session.add(message) + session.flush() + return message + + +def _create_submitted_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + action_id: str = "approve", + action_title: str = "Approve", + node_title: str = "Approval", +) -> HumanInputForm: + expiration_time = datetime.utcnow() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content=f"Rendered {action_title}", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + selected_action_id=action_id, + ) + session.add(form) + session.flush() + return form + + +def _create_waiting_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + default_values: dict | None = None, +) -> HumanInputForm: + expiration_time = datetime.utcnow() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values=default_values or {"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + session.add(form) + session.flush() + return form + + +def _create_human_input_content( + session: Session, + *, + workflow_run_id: str, + message_id: str, + form_id: str, +) -> HumanInputContent: + content = HumanInputContent.new( + workflow_run_id=workflow_run_id, + message_id=message_id, + form_id=form_id, + ) + session.add(content) + return content + + +def _create_recipient( + session: Session, + *, + form_id: str, + delivery_id: str, + recipient_type: RecipientType = RecipientType.CONSOLE, + access_token: str = "token-1", +) -> HumanInputFormRecipient: + payload = ConsoleRecipientPayload(account_id=None) + recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=recipient_type, + recipient_payload=payload.model_dump_json(), + access_token=access_token, + ) + session.add(recipient) + return recipient + + +def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: + from dify_graph.nodes.human_input.enums import DeliveryMethodType + from models.human_input import ConsoleDeliveryPayload + + delivery = HumanInputDelivery( + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + session.add(delivery) + session.flush() + return delivery + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + _seed_base_entities(db_session_with_containers, scope) + db_session_with_containers.commit() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetByMessageIds: + """Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids.""" + + def test_groups_contents_by_message( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Submitted forms are correctly mapped and grouped by message ID.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_submitted_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + action_id="approve", + action_title="Approve", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg1.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg1.id, msg2.id]) + + assert len(result) == 2 + # msg1 has one submitted content + assert len(result[0]) == 1 + content = result[0][0] + assert content.submitted is True + assert content.workflow_run_id == workflow_run_id + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == "approve" + assert content.form_submission_data.action_text == "Approve" + assert content.form_submission_data.rendered_content == "Rendered Approve" + assert content.form_submission_data.node_id == "node-id" + assert content.form_submission_data.node_title == "Approval" + # msg2 has no content + assert result[1] == [] + + def test_returns_unsubmitted_form_definition( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Waiting forms return full form_definition with resolved token and defaults.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_waiting_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + default_values={"name": "John"}, + ) + delivery = _create_delivery(db_session_with_containers, form_id=form.id) + _create_recipient( + db_session_with_containers, + form_id=form.id, + delivery_id=delivery.id, + access_token="token-1", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg.id]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == workflow_run_id + assert domain_content.form_definition is not None + form_def = domain_content.form_definition + assert form_def.form_id == form.id + assert form_def.node_id == "node-id" + assert form_def.node_title == "Approval" + assert form_def.form_content == "Rendered block" + assert form_def.display_in_ui is True + assert form_def.form_token == "token-1" + assert form_def.resolved_default_values == {"name": "John"} + assert form_def.expiration_time == int(form.expiration_time.timestamp()) + + def test_empty_message_ids_returns_empty_list( + self, + repository: SQLAlchemyExecutionExtraContentRepository, + ) -> None: + """Passing no message IDs returns an empty list without hitting the DB.""" + result = repository.get_by_message_ids([]) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py new file mode 100644 index 0000000000..1568d5d65c --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -0,0 +1,391 @@ +"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import Session, sessionmaker + +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType +from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository + + +class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + """Concrete repository for tests where save() is not under test.""" + + def save(self, execution: WorkflowExecution) -> None: + return None + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows.""" + + tenant_id: str = field(default_factory=lambda: str(uuid4())) + app_id: str = field(default_factory=lambda: str(uuid4())) + workflow_id: str = field(default_factory=lambda: str(uuid4())) + user_id: str = field(default_factory=lambda: str(uuid4())) + + +def _create_workflow_run( + session: Session, + scope: _TestScope, + *, + status: WorkflowExecutionStatus, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + created_at_offset: timedelta | None = None, +) -> WorkflowRun: + """Create and persist a workflow run bound to the current test scope.""" + now = naive_utc_now() + workflow_run = WorkflowRun( + id=str(uuid4()), + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_id=scope.workflow_id, + type=WorkflowType.WORKFLOW, + triggered_from=triggered_from, + version="draft", + graph="{}", + inputs="{}", + status=status, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=scope.user_id, + created_at=now + created_at_offset if created_at_offset is not None else now, + ) + session.add(workflow_run) + session.commit() + return workflow_run + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == scope.tenant_id, + WorkflowRun.app_id == scope.app_id, + ) + ) + session.commit() + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> _TestScope: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetPaginatedWorkflowRuns: + """Integration tests for get_paginated_workflow_runs.""" + + def test_returns_runs_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return all runs for the given tenant/app when no status filter is applied.""" + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.RUNNING, + ): + _create_workflow_run(db_session_with_containers, test_scope, status=status) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.limit == 20 + assert result.has_more is False + + def test_filters_by_status( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return only runs matching the requested status.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status="succeeded", + ) + + assert len(result.data) == 2 + assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data) + + def test_pagination_has_more( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return has_more=True when more records exist beyond the limit.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.has_more is True + + def test_cursor_based_pagination( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Cursor-based pagination returns the next page of results.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + # First page + page1 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + assert len(page1.data) == 3 + assert page1.has_more is True + + # Second page using cursor + page2 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=page1.data[-1].id, + status=None, + ) + assert len(page2.data) == 2 + assert page2.has_more is False + + # No overlap between pages + page1_ids = {r.id for r in page1.data} + page2_ids = {r.id for r in page2.data} + assert page1_ids.isdisjoint(page2_ids) + + def test_invalid_last_id_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + test_scope: _TestScope, + ) -> None: + """Raise ValueError when last_id refers to a non-existent run.""" + with pytest.raises(ValueError, match="Last workflow run not exists"): + repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=str(uuid4()), + status=None, + ) + + def test_tenant_isolation( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Runs from other tenants are not returned.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + other_scope = _TestScope(app_id=test_scope.app_id) + try: + _create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 1 + assert result.data[0].tenant_id == test_scope.tenant_id + finally: + _cleanup_scope_data(db_session_with_containers, other_scope) + + +class TestGetWorkflowRunsCount: + """Integration tests for get_workflow_runs_count.""" + + def test_count_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count all runs grouped by status when no status filter is applied.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + for _ in range(2): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + ) + + assert result["total"] == 6 + assert result["succeeded"] == 3 + assert result["failed"] == 2 + assert result["running"] == 1 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_count_with_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count only runs matching the requested status.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + ) + + assert result["total"] == 3 + assert result["succeeded"] == 3 + assert result["failed"] == 0 + + def test_count_with_invalid_status_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Invalid status raises StatementError because the column uses an enum type.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + with pytest.raises(sa_exc.StatementError) as exc_info: + repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="invalid_status", + ) + assert isinstance(exc_info.value.orig, ValueError) + + def test_count_with_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Time range filter excludes runs created outside the window.""" + # Recent run (within 1 day) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Old run (8 days ago) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + + def test_count_with_status_and_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Both status and time_range filters apply together.""" + # Recent succeeded + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Recent failed + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + # Old succeeded (outside time range) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + assert result["failed"] == 0 diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 251f17dd03..42d587b7f7 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,8 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -90,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status @@ -198,7 +200,7 @@ class DocumentStatusTestDataFactory: """ upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"uploads/{uuid4()}", name=name, size=128, diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4759d244fd..b51fbc3a42 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account +from models.enums import ConversationFromSource, MessageFileBelongsTo from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService @@ -164,7 +165,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -203,7 +204,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -405,7 +406,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -444,7 +445,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -477,7 +478,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(conversation) db_session_with_containers.commit() @@ -516,7 +517,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -623,7 +624,7 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, app_model_config_id=None, # Explicitly set to None ) db_session_with_containers.add(conversation) @@ -646,7 +647,7 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) db_session_with_containers.add(message) db_session_with_containers.commit() @@ -852,7 +853,7 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file1.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) @@ -861,7 +862,7 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file2.png", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index a260d823a2..95fc73f45a 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account +from models.enums import ConversationFromSource, InvokeFrom from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService @@ -136,8 +137,8 @@ class TestAnnotationService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -174,8 +175,8 @@ class TestAnnotationService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -721,7 +722,7 @@ class TestAnnotationService: query=f"Query {i}: {fake.sentence()}", user_id=account.id, message_id=fake.uuid4(), - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=0.8 + (i * 0.1), ) @@ -772,7 +773,7 @@ class TestAnnotationService: query=query, user_id=account.id, message_id=message_id, - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=score, ) diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 7ce7357b41..b8e022503f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -525,3 +525,147 @@ class TestAPIBasedExtensionService: # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + + def test_save_extension_api_key_exactly_four_chars_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 4 characters should be rejected (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="1234", + ) + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_api_key_exactly_five_chars_accepted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 5 characters should be accepted (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="12345", + ) + + saved = APIBasedExtensionService.save(extension_data) + assert saved.id is not None + + def test_save_extension_requestor_constructor_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Exception raised by requestor constructor is wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config") + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_network_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Network exceptions during ping are wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError( + "network failure" + ) + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_update_duplicate_name_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Updating an existing extension to use another extension's name should fail.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + ext1 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Alpha", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + ext2 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Beta", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + # Try to rename ext2 to ext1's name + ext2.name = "Extension Alpha" + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(ext2) + + def test_get_all_returns_empty_for_different_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Extensions from one tenant should not be visible to another.""" + fake = Faker() + _, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + _, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant1 is not None + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant1.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + assert tenant2 is not None + result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..a83af30fb9 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. @@ -1142,3 +1245,51 @@ class TestAppService: assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) + + def test_get_app_code_by_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_code_by_id raises ValueError when site is missing.""" + from uuid import uuid4 + + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id(str(uuid4())) + + def test_get_app_id_by_code_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_id_by_code raises ValueError when code does not exist.""" + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("nonexistent-code") + + def test_get_app_meta_returns_empty_when_workflow_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when workflow is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + workflow_app = SimpleNamespace(mode="workflow", workflow=None) + + meta = app_service.get_app_meta(workflow_app) + assert meta == {"tool_icons": {}} + + def test_get_app_meta_returns_empty_when_model_config_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when app_model_config is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + chat_app = SimpleNamespace(mode="chat", app_model_config=None) + + meta = app_service.get_app_meta(chat_app) + assert meta == {"tool_icons": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py new file mode 100644 index 0000000000..768a8baee2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -0,0 +1,80 @@ +"""Testcontainers integration tests for AttachmentService.""" + +import base64 +from datetime import UTC, datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from extensions.ext_database import db +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id or str(uuid4()), + storage_type=StorageType.OPENDAL, + key=f"upload/{uuid4()}.txt", + name="test-file.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + def test_should_initialize_with_sessionmaker(self): + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_engine(self): + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_when_file_exists(self, db_session_with_containers): + upload_file = self._create_upload_file(db_session_with_containers) + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64(upload_file.id) + + assert result == base64.b64encode(b"binary-content").decode() + mock_load.assert_called_once_with(upload_file.key) + + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64(str(uuid4())) + + mock_load.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 5f64e6f674..6180d98b1e 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -10,6 +10,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.conversation_service import ConversationService @@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory: system_instruction_tokens=0, status="normal", invoke_from=invoke_from.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, @@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory: currency="USD", status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py new file mode 100644 index 0000000000..42a2215896 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,58 @@ +"""Testcontainers integration tests for ConversationVariableUpdater.""" + +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from dify_graph.variables import StringVariable +from extensions.ext_database import db +from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def _create_conversation_variable( + self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + ) -> ConversationVariable: + row = ConversationVariable( + id=variable.id, + conversation_id=conversation_id, + app_id=app_id or str(uuid4()), + data=variable.model_dump_json(), + ) + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="old value") + self._create_conversation_variable( + db_session_with_containers, conversation_id=conversation_id, variable=variable + ) + + updated_variable = StringVariable(id=variable.id, name="topic", value="new value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + updater.update(conversation_id=conversation_id, variable=updated_variable) + + db_session_with_containers.expire_all() + row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id)) + assert row is not None + assert row.data == updated_variable.model_dump_json() + + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + result = updater.flush() + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..0f63d98642 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -0,0 +1,104 @@ +"""Testcontainers integration tests for CreditPoolService.""" + +from uuid import uuid4 + +import pytest + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from models.enums import ProviderQuotaType +from services.credit_pool_service import CreditPoolService + + +class TestCreditPoolService: + def _create_tenant_id(self) -> str: + return str(uuid4()) + + def test_create_default_pool(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + + pool = CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == ProviderQuotaType.TRIAL + assert pool.quota_used == 0 + assert pool.quota_limit > 0 + + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) + + assert result is not None + assert result.tenant_id == tenant_id + assert result.pool_type == ProviderQuotaType.TRIAL + + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) + + assert result is None + + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) + + assert result is False + + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10) + + assert result is True + + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + # Exhaust credits + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1) + + assert result is False + + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) + + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) + + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + credits_required = 10 + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required) + + assert result == credits_required + db_session_with_containers.expire_all() + pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert pool.quota_used == credits_required + + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + remaining = 5 + pool.quota_used = pool.quota_limit - remaining + db_session_with_containers.commit() + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + + assert result == remaining + db_session_with_containers.expire_all() + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool.quota_used == pool.quota_limit diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 975af3d428..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", @@ -397,6 +398,68 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + """Test that users from different tenants cannot access dataset.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other_user) + + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + """Test that tenant owners can access any dataset regardless of permission level.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, owner) + + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + """Test ONLY_ME permission allows only the dataset creator to access.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + """Test ONLY_ME permission denies access to non-creators.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other) + + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + """Test ALL_TEAM permission allows any team member to access the dataset.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, member) + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): """ Test that user with explicit permission can access partial_members dataset. @@ -443,6 +506,16 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, creator) + class TestDatasetServiceCheckDatasetOperatorPermission: """Verify operator permission checks against persisted partial-member permissions.""" diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..0702680f5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -180,13 +181,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, @@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] + + +class TestDocumentServicePauseRecoverRetry: + """Tests for pause/recover/retry orchestration using real DB and Redis.""" + + def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc = factory.create_document(db_session_with_containers, dataset, account.id) + doc.indexing_status = indexing_status + db_session_with_containers.commit() + return doc, account + + def test_pause_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is True + assert doc.paused_by == account.id + assert doc.paused_at is not None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is not None + redis_client.delete(cache_key) + + def test_pause_document_invalid_status_error(self, db_session_with_containers): + from services.dataset_service import DocumentService + from services.errors.document import DocumentIndexingError + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(doc) + + def test_recover_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + # Pause first + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + # Recover + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is False + assert doc.paused_by is None + assert doc.paused_at is None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is None + recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) + + def test_retry_document_indexing_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt") + doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt") + doc2.position = 2 + doc1.indexing_status = "error" + doc2.indexing_status = "error" + db_session_with_containers.commit() + + with ( + patch("services.dataset_service.current_user") as mock_user, + patch("services.dataset_service.retry_document_indexing_task") as retry_task, + ): + mock_user.id = account.id + DocumentService.retry_document(dataset.id, [doc1, doc2]) + + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + assert doc1.indexing_status == "waiting" + assert doc2.indexing_status == "waiting" + + # Verify redis keys were set + assert redis_client.get(f"document_{doc1.id}_is_retried") is not None + assert redis_client.get(f"document_{doc2.id}_is_retried") is not None + retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id) + + # Cleanup + redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried") diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 7983b1cd93..c1d088755c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled @@ -694,3 +695,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) + + def test_batch_update_invalid_action_raises_value_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Test that an invalid action raises ValueError.""" + factory = DocumentBatchUpdateIntegrationDataFactory + dataset = factory.create_dataset(db_session_with_containers) + doc = factory.create_document(db_session_with_containers, dataset) + user = UserDouble(id=str(uuid4())) + + patched_dependencies["redis_client"].get.return_value = None + + with pytest.raises(ValueError, match="Invalid action"): + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py new file mode 100644 index 0000000000..c486ff5613 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -0,0 +1,60 @@ +"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.account import Account, Tenant, TenantAccountJoin +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity + + +class TestDatasetServiceCreateRagPipelineDataset: + def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"ds_create_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return tenant, account + + def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity: + icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji") + return RagPipelineDatasetCreateEntity( + name=name, + description="", + icon_info=icon_info, + permission="only_me", + ) + + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + tenant, _ = self._create_tenant_and_account(db_session_with_containers) + + mock_user = Mock(id=None) + with patch("services.dataset_service.current_user", mock_user): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=self._build_entity(), + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c9..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act @@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..2899d5b8a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,6 +4,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 5f86cb2ae9..376a89d1ce 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None + + def test_delete_run_dry_run(self, db_session_with_containers): + """Dry run should return success without actually deleting.""" + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + run_id = run.id + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + + result = deleter._delete_run(run) + + assert result.success is True + assert result.run_id == run_id + # Run should still exist because it's a dry run + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is not None + + def test_delete_run_exception_returns_error(self, db_session_with_containers): + """Exception during deletion should return failure result.""" + from unittest.mock import MagicMock, patch + + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + deleter = ArchivedWorkflowRunDeletion(dry_run=False) + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deleter._delete_run(run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_by_run_id_success(self, db_session_with_containers): + """Successfully delete an archived workflow run by ID.""" + tenant_id = str(uuid4()) + base_time = datetime.now(UTC) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=base_time, + ) + self._create_archive_log(db_session_with_containers, run=run) + run_id = run.id + + deleter = ArchivedWorkflowRunDeletion() + result = deleter.delete_by_run_id(run_id) + + assert result.success is True + db_session_with_containers.expunge_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is None + + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + """_get_workflow_run_repo should return a cached repo on subsequent calls.""" + deleter = ArchivedWorkflowRunDeletion() + + repo1 = deleter._get_workflow_run_repo() + repo2 = deleter._get_workflow_run_repo() + + assert repo1 is repo2 + assert deleter.workflow_run_repo is repo1 diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c733..c0047df810 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status @@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index b159af0090..34532ed7f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,8 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom @@ -68,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" @@ -83,7 +85,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n """Persist an upload file row referenced by document.data_source_info.""" upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"uploads/{uuid4()}", name=name, size=128, diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index ae811db768..cafabc939b 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById: ) assert result is None + + +class TestEndUserServiceCreateBatch: + """Integration tests for EndUserService.create_end_user_batch.""" + + @pytest.fixture + def factory(self): + return TestEndUserServiceFactory() + + def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + """Create multiple apps under the same tenant.""" + first_app = factory.create_app_and_account(db_session_with_containers) + tenant_id = first_app.tenant_id + apps = [first_app] + for _ in range(count - 1): + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=first_app.created_by, + updated_by=first_app.updated_by, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() + return tenant_id, all_apps + + def test_create_batch_empty_app_ids(self, db_session_with_containers): + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" + ) + assert result == {} + + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 3 + for app_id in app_ids: + assert app_id in result + assert result[app_id].session_id == user_id + assert result[app_id].type == InvokeFrom.SERVICE_API + + def test_create_batch_default_session_id(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="" + ) + + assert len(result) == 2 + for end_user in result.values(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 2 + + def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + # Create batch first time + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Create batch second time — should return existing users + second_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(second_result) == 2 + for app_id in app_ids: + assert first_result[app_id].id == second_result[app_id].id + + def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + user_id = f"user-{uuid4()}" + + # Create for first 2 apps + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[apps[0].id, apps[1].id], + user_id=user_id, + ) + + # Create for all 3 apps — should reuse first 2, create 3rd + all_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[a.id for a in apps], + user_id=user_id, + ) + + assert len(all_result) == 3 + assert all_result[apps[0].id].id == first_result[apps[0].id].id + assert all_result[apps[1].id].id == first_result[apps[1].id].id + assert all_result[apps[2].id].session_id == user_id + + @pytest.mark.parametrize( + "invoke_type", + [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], + ) + def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id + ) + + assert len(result) == 1 + assert result[apps[0].id].type == invoke_type diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index 60919dff0d..771f406775 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -8,6 +8,7 @@ from unittest import mock import pytest from extensions.ext_database import db +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, Conversation, Message from services.feedback_service import FeedbackService @@ -47,8 +48,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content="Great answer!", from_end_user_id="user-123", from_account_id=None, @@ -61,8 +62,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="Could be more detailed", from_end_user_id=None, from_account_id="admin-456", @@ -179,8 +180,8 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( app_id=sample_data["app"].id, - from_source="admin", - rating="dislike", + from_source=FeedbackFromSource.ADMIN, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", @@ -293,8 +294,8 @@ class TestFeedbackService: app_id=sample_data["app"].id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="user", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, content="回答不够详细,需要更多信息", from_end_user_id="user-123", from_account_id=None, diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 50f5b7a8c0..42dbdef1c9 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config +from extensions.storage.storage_type import StorageType from models import Account, Tenant from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -140,7 +141,7 @@ class TestFileService: upload_file = UploadFile( tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()), - storage_type="local", + storage_type=StorageType.LOCAL, key=f"upload_files/test/{fake.uuid4()}.txt", name="test_file.txt", size=1024, diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 0000000000..4e0a726cc7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,96 @@ +""" +Testcontainers integration tests for FileService helpers. + +Covers: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (get_upload_files_by_ids) +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.file_service import FileService + + +def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.OPENDAL, + key=key, + name=name, + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session.add(upload_file) + db_session.commit() + return upload_file + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + tenant_id = str(uuid4()) + file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") + file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") + + result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + + assert set(result.keys()) == {file1.id, file2.id} + assert result[file1.id].id == file1.id + assert result[file2.id].id == file2.id + + +def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None: + """Ensure files from other tenants are not returned.""" + tenant_a = str(uuid4()) + tenant_b = str(uuid4()) + file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") + _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") + + result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + + assert set(result.keys()) == {file_a.id} diff --git a/api/tests/test_containers_integration_tests/services/test_message_export_service.py b/api/tests/test_containers_integration_tests/services/test_message_export_service.py index 200f688ae9..00dfe9dda4 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_export_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_export_service.py @@ -7,6 +7,7 @@ import pytest from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating from models.model import ( App, AppAnnotationHitHistory, @@ -93,7 +94,7 @@ class TestAppMessageExportServiceIntegration: name="conv", inputs={"seed": 1}, status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid.uuid4()), ) session.add(conversation) @@ -128,7 +129,7 @@ class TestAppMessageExportServiceIntegration: total_price=Decimal("0.003"), currency="USD", message_metadata=message_metadata, - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=conversation.from_end_user_id, created_at=created_at, ) @@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content="first", from_end_user_id=conversation.from_end_user_id, ) @@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="dislike", - from_source="user", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, content="second", from_end_user_id=conversation.from_end_user_id, ) @@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="like", - from_source="admin", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, content="should-be-filtered", from_account_id=str(uuid.uuid4()), ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index a6d7bf27fd..85dc04b162 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom from models.model import MessageFeedback from services.app_service import AppService from services.errors.message import ( @@ -148,8 +149,8 @@ class TestMessageService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -186,8 +187,8 @@ class TestMessageService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) @@ -405,7 +406,7 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback - rating = "like" + rating = FeedbackRating.LIKE content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=rating, content=content @@ -435,7 +436,11 @@ class TestMessageService: # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): MessageService.create_feedback( - app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=None, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) def test_create_feedback_update_existing( @@ -452,14 +457,14 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback - initial_rating = "like" + initial_rating = FeedbackRating.LIKE initial_content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content ) # Update feedback - updated_rating = "dislike" + updated_rating = FeedbackRating.DISLIKE updated_content = fake.text(max_nb_chars=100) updated_feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content @@ -487,7 +492,11 @@ class TestMessageService: # Create initial feedback feedback = MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) # Delete feedback by setting rating to None @@ -538,7 +547,7 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, - rating="like" if i % 2 == 0 else "dislike", + rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE, content=f"Feedback {i}: {fake.text(max_nb_chars=50)}", ) feedbacks.append(feedback) @@ -568,7 +577,11 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=f"Feedback {i}", ) # Get feedbacks with pagination diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py index 772365ba54..f2cb667204 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py @@ -4,6 +4,7 @@ from decimal import Decimal import pytest +from models.enums import ConversationFromSource from models.model import Message from services import message_service from tests.test_containers_integration_tests.helpers.execution_extra_content import ( @@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit total_price=Decimal(0), currency="USD", status="normal", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=fixture.account.id, ) db_session_with_containers.add(message_without_extra_content) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index ef1f31d36b..57bbc73b50 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,10 +8,18 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from dify_graph.file.enums import FileType from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import DataSourceType +from models.enums import ( + ConversationFromSource, + DataSourceType, + FeedbackFromSource, + FeedbackRating, + MessageChainType, + MessageFileBelongsTo, +) from models.model import ( App, AppAnnotationHitHistory, @@ -166,7 +174,7 @@ class TestMessagesCleanServiceIntegration: name="Test conversation", inputs={}, status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid.uuid4()), ) db_session_with_containers.add(conversation) @@ -196,7 +204,7 @@ class TestMessagesCleanServiceIntegration: answer_unit_price=Decimal("0.002"), total_price=Decimal("0.003"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, from_account_id=conversation.from_end_user_id, created_at=created_at, ) @@ -216,8 +224,8 @@ class TestMessagesCleanServiceIntegration: app_id=message.app_id, conversation_id=message.conversation_id, message_id=message.id, - rating="like", - from_source="api", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, from_end_user_id=str(uuid.uuid4()), ) db_session_with_containers.add(feedback) @@ -236,7 +244,7 @@ class TestMessagesCleanServiceIntegration: # MessageChain chain = MessageChain( message_id=message.id, - type="system", + type=MessageChainType.SYSTEM, input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) @@ -246,10 +254,10 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role="end_user", created_by=str(uuid.uuid4()), ) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b..8b1349be9a 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index dd743d46c2..d256c0d90b 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage from services.app_service import AppService @@ -132,11 +133,14 @@ class TestSavedMessageService: # Create a simple conversation first from models.model import Conversation + is_account = hasattr(user, "current_tenant") + from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API + conversation = Conversation( app_id=app.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, name=fake.sentence(nb_words=3), inputs={}, status="normal", @@ -150,9 +154,9 @@ class TestSavedMessageService: message = Message( app_id=app.id, conversation_id=conversation.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, inputs={}, query=fake.sentence(nb_words=5), message=fake.text(max_nb_chars=100), @@ -392,11 +396,6 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) - # Verify no database operations were performed - - saved_messages = db_session_with_containers.query(SavedMessage).all() - assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -493,124 +492,140 @@ class TestSavedMessageService: # The message should still exist, only the saved_message should be deleted assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): - """ - Test error handling when no user is provided. - - This test verifies: - - Proper error handling for missing user - - ValueError is raised when user is None - - No database operations are performed - """ - # Arrange: Create test data - fake = Faker() + def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test saving a message for an EndUser.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) - # Act & Assert: Verify proper error handling - with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + mock_external_service_dependencies["message_service"].get_message.return_value = message - assert "User is required" in str(exc_info.value) + SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) - # Verify no database operations were performed for this specific test - # Note: We don't check total count as other tests may have created data - # Instead, we verify that the error was properly raised - pass - - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): - """ - Test error handling when saving message with no user. - - This test verifies: - - Method returns early when user is None - - No database operations are performed - - No exceptions are raised - """ - # Arrange: Create test data - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - message = self._create_test_message(db_session_with_containers, app, account) - - # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) - - # Assert: Verify the expected outcomes - assert result is None - - # Verify no saved message was created - - saved_message = ( + saved = ( db_session_with_containers.query(SavedMessage) - .where( - SavedMessage.app_id == app.id, - SavedMessage.message_id == message.id, - ) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) .first() ) + assert saved is not None + assert saved.created_by == end_user.id + assert saved.created_by_role == "end_user" - assert saved_message is None - - def test_delete_success_existing_message( + def test_save_duplicate_is_idempotent( self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test successful deletion of an existing saved message. - - This test verifies: - - Proper deletion of existing saved message - - Correct database state after deletion - - No errors during deletion process - """ - # Arrange: Create test data - fake = Faker() + """Test that saving an already-saved message does not create a duplicate.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) message = self._create_test_message(db_session_with_containers, app, account) - # Create a saved message first - saved_message = SavedMessage( - app_id=app.id, - message_id=message.id, - created_by_role="account", - created_by=account.id, - ) + mock_external_service_dependencies["message_service"].get_message.return_value = message - db_session_with_containers.add(saved_message) + # Save once + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + # Save again + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + count = ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .count() + ) + assert count == 1 + + def test_delete_without_user_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting without a user is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Pre-create a saved message + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + db_session_with_containers.add(saved) db_session_with_containers.commit() - # Verify saved message exists + SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + + # Should still exist + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is not None + ) + + def test_delete_non_existent_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting a non-existent saved message is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Should not raise — use a valid UUID that doesn't exist in DB + from uuid import uuid4 + + SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + + def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test deleting a saved message for an EndUser.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) + + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + db_session_with_containers.add(saved) + db_session_with_containers.commit() + + SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is None + ) + + def test_delete_only_affects_own_saved_messages( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that delete only removes the requesting user's saved message.""" + app, account1 = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, account1) + + # Both users save the same message + saved_account = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + ) + saved_end_user = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + ) + db_session_with_containers.add_all([saved_account, saved_end_user]) + db_session_with_containers.commit() + + # Delete only account1's saved message + SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + + # Account's saved message should be gone assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == account1.id, ) .first() - is not None + is None ) - - # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) - - # Assert: Verify the expected outcomes - # Check if saved message was deleted from database - deleted_saved_message = ( + # End user's saved message should still exist + assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == end_user.id, ) .first() + is not None ) - - assert deleted_saved_message is None - - # Verify database state - db_session_with_containers.commit() - # The message should still exist, only the saved_message should be deleted - assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index fa6e651529..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,9 +7,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset -from models.enums import DataSourceType +from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) @@ -547,7 +548,7 @@ class TestTagService: assert result is not None assert len(result) == 1 assert result[0].name == "python_tag" - assert result[0].type == "app" + assert result[0].type == TagType.APP assert result[0].tenant_id == tenant.id def test_get_tag_by_tag_name_no_matches( @@ -638,7 +639,7 @@ class TestTagService: # Verify all tags are returned for tag in result: - assert tag.type == "app" + assert tag.type == TagType.APP assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 425611744b..6b95954480 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account +from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation from services.account_service import AccountService, TenantService @@ -145,7 +146,7 @@ class TestWebConversationService: system_instruction_tokens=50, status="normal", invoke_from=InvokeFrom.WEB_APP, - from_source="console" if isinstance(user, Account) else "api", + from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..880143013e 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import json import uuid from datetime import UTC, datetime, timedelta +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -8,13 +11,14 @@ from faker import Faker from sqlalchemy.orm import Session from dify_graph.entities.workflow_execution import WorkflowExecutionStatus -from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun -from models.enums import CreatorUserRole +from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService -from services.workflow_app_service import WorkflowAppService +from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -221,7 +225,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +361,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +403,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +445,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +525,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +631,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +736,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +864,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +906,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1041,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1129,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1283,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1383,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1485,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1524,3 +1528,168 @@ class TestWorkflowAppService: # Should not find tenant2's data when searching from tenant1's context assert result_cross_tenant["total"] == 0 + + def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"): + service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + ) + + def test_get_paginate_workflow_app_logs_filters_by_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account) + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + ) + + assert result["total"] >= 0 + assert isinstance(result["data"], list) + + def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type="browser", + is_anonymous=False, + session_id="session-1", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + + now = datetime.now(UTC) + archive_defaults = { + "workflow_id": str(uuid.uuid4()), + "run_version": "1.0.0", + "run_status": WorkflowExecutionStatus.SUCCEEDED, + "run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, + "run_error": None, + "run_elapsed_time": 1.0, + "run_total_tokens": 0, + "run_total_steps": 0, + "run_created_at": now, + "run_finished_at": now, + "run_exceptions_count": 0, + "trigger_metadata": '{"type":"trigger-webhook"}', + "log_created_at": now, + "log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API, + } + archive_account = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=account.id, + created_by_role=CreatorUserRole.ACCOUNT, + **archive_defaults, + ) + archive_end_user = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=end_user.id, + created_by_role=CreatorUserRole.END_USER, + **archive_defaults, + ) + db_session_with_containers.add_all([archive_account, archive_end_user]) + db_session_with_containers.commit() + + result = service.get_paginate_workflow_archive_logs( + session=db_session_with_containers, + app_model=app, + page=1, + limit=20, + ) + + assert result["total"] == 2 + assert len(result["data"]) == 2 + account_item = next(d for d in result["data"] if d["created_by_account"] is not None) + end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None) + assert account_item["created_by_account"].id == account.id + assert end_user_item["created_by_end_user"].id == end_user.id + + +class TestLogView: + def test_details_and_proxy_attributes(self): + log = SimpleNamespace(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + assert view.details == {"trigger_metadata": {"type": "plugin"}} + assert view.status == "succeeded" + + +class TestHandleTriggerMetadata: + def test_returns_empty_dict_when_metadata_missing(self): + service = WorkflowAppService() + assert service.handle_trigger_metadata("tenant-1", None) == {} + + def test_enriches_plugin_icons(self): + service = WorkflowAppService() + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + with patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + def test_non_plugin_metadata_without_icon_lookup(self): + service = WorkflowAppService() + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +class TestSafeJsonLoads: + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], + ) + def test_handles_various_inputs(self, value, expected): + assert WorkflowAppService._safe_json_loads(value) == expected + + +class TestSafeParseUuid: + def test_returns_none_for_short_or_invalid_values(self): + service = WorkflowAppService() + assert service._safe_parse_uuid("short") is None + assert service._safe_parse_uuid("x" * 40) is None + + def test_returns_uuid_for_valid_string(self): + service = WorkflowAppService() + raw = str(uuid.uuid4()) + result = service._safe_parse_uuid(raw) + assert result is not None + assert str(result) == raw diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index e080d6ef6b..731770e01a 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -7,7 +7,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import ( Message, ) @@ -165,7 +165,7 @@ class TestWorkflowRunService: inputs={}, status="normal", mode="chat", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, ) db_session_with_containers.add(conversation) @@ -186,7 +186,7 @@ class TestWorkflowRunService: message.answer_price_unit = 0.001 message.currency = "USD" message.status = "normal" - message.from_source = CreatorUserRole.ACCOUNT + message.from_source = ConversationFromSource.CONSOLE message.from_account_id = account.id message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 056db41750..a5fe052206 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -802,6 +802,81 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) + def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features( + self, db_session_with_containers: Session + ): + """Restore copies legacy feature JSON into draft without rewriting the source row.""" + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + app.mode = AppMode.ADVANCED_CHAT + + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + published_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version="2026.03.19.001", + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(legacy_features), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + draft_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + db_session_with_containers.add(published_workflow) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + restored_workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=published_workflow.id, + account=account, + ) + + db_session_with_containers.expire_all() + refreshed_published_workflow = ( + db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first() + ) + refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first() + + assert restored_workflow.id == draft_workflow.id + assert refreshed_published_workflow is not None + assert refreshed_draft_workflow is not None + assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features) + assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features) + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index f3736333ea..2dc50cc720 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -1,12 +1,24 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest from faker import Faker from sqlalchemy.orm import Session -from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolDescription, + ToolEntity, + ToolIdentity, + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -48,41 +60,42 @@ class TestToolTransformService: name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', - icon_dark='{"background": "#252525", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", - credentials={"auth_type": "api_key_header", "api_key": "test_key"}, - provider_type="api", + credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', + schema="{}", + schema_type_str=ApiProviderSchemaType.OPENAPI, + tools_str="[]", ) elif provider_type == "builtin": provider = BuiltinToolProvider( name=fake.company(), - description=fake.text(max_nb_chars=100), - icon="🔧", - icon_dark="🔧", tenant_id="test_tenant_id", + user_id="test_user_id", provider="test_provider", credential_type="api_key", - credentials={"api_key": "test_key"}, + encrypted_credentials='{"api_key": "test_key"}', ) elif provider_type == "workflow": provider = WorkflowToolProvider( name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', - icon_dark='{"background": "#252525", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", - workflow_id="test_workflow_id", + app_id="test_workflow_id", + label="Test Workflow", + version="1.0.0", + parameter_configuration="[]", ) elif provider_type == "mcp": provider = MCPToolProvider( name=fake.company(), - description=fake.text(max_nb_chars=100), - provider_icon='{"background": "#FF6B6B", "content": "🔧"}', + icon='{"background": "#FF6B6B", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", server_url="https://mcp.example.com", + server_url_hash="test_server_url_hash", server_identifier="test_server", tools='[{"name": "test_tool", "description": "Test tool"}]', authed=True, @@ -658,7 +671,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -694,7 +707,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -730,7 +743,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -785,3 +798,192 @@ class TestToolTransformService: assert result is not None assert result == mock_controller mock_from_db.assert_called_once_with(provider) + + +def _mock_tool(*, base_params, runtime_params): + """Helper to build a Mock tool with real entity objects. + + Tool is abstract and requires runtime behaviour (fork_tool_runtime, + get_runtime_parameters), so it stays as a Mock. Everything else uses + real Pydantic instances. + """ + entity = ToolEntity( + identity=ToolIdentity( + author="test_author", + name="test_tool", + label=I18nObject(en_US="Test Tool"), + provider="test_provider", + ), + parameters=base_params or [], + description=ToolDescription( + human=I18nObject(en_US="Test description"), + llm="Test description for LLM", + ), + output_schema={}, + ) + mock_tool = Mock(spec=Tool) + mock_tool.entity = entity + mock_tool.get_runtime_parameters.return_value = runtime_params + mock_tool.fork_tool_runtime.return_value = mock_tool + return mock_tool + + +def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None): + return ToolParameter( + name=name, + label=I18nObject(en_US=label or name), + human_description=I18nObject(en_US=name), + type=ToolParameter.ToolParameterType.STRING, + form=form, + ) + + +class TestConvertToolEntityToApiEntity: + """Tests for ToolTransformService.convert_tool_entity_to_api_entity.""" + + def test_parameter_override(self): + base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")] + runtime = [_param("param1", label="Runtime 1")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 2 + assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1" + assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2" + + def test_additional_runtime_parameters(self): + base = [_param("param1", label="Base 1")] + runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 2 + names = [p.name for p in result.parameters] + assert "param1" in names + assert "runtime_only" in names + + def test_non_form_runtime_parameters_excluded(self): + base = [_param("param1")] + runtime = [ + _param("param1", label="Runtime 1"), + _param("llm_param", form=ToolParameter.ToolParameterForm.LLM), + ] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 1 + assert result.parameters[0].name == "param1" + + def test_empty_parameters(self): + tool = _mock_tool(base_params=[], runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_none_parameters(self): + tool = _mock_tool(base_params=None, runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_parameter_order_preserved(self): + base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")] + runtime = [_param("p2", label="R2"), _param("p4", label="R4")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"] + assert result.parameters[1].label.en_US == "R2" + + +class TestWorkflowProviderToUserProvider: + """Tests for ToolTransformService.workflow_provider_to_user_provider.""" + + @staticmethod + def _make_controller(provider_id="provider_123", **identity_overrides): + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + defaults = { + "author": "test_author", + "name": "test_workflow_tool", + "description": I18nObject(en_US="Test description"), + "icon": '{"type": "emoji", "content": "🔧"}', + "icon_dark": None, + "label": I18nObject(en_US="Test Workflow Tool"), + } + defaults.update(identity_overrides) + identity = ToolProviderIdentity(**defaults) + entity = ToolProviderEntity(identity=identity) + return WorkflowToolProviderController(entity=entity, provider_id=provider_id) + + def test_with_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1", "l2"], + workflow_app_id="app_123", + ) + + assert isinstance(result, ToolProviderApiEntity) + assert result.id == "provider_123" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_123" + assert result.labels == ["l1", "l2"] + assert result.is_team_authorization is True + + def test_without_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1"], + ) + + assert result.workflow_app_id is None + + def test_workflow_app_id_none_explicit(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=None, + workflow_app_id=None, + ) + + assert result.workflow_app_id is None + assert result.labels == [] + + def test_preserves_other_fields(self): + ctrl = self._make_controller( + "provider_456", + author="another_author", + name="another_workflow_tool", + description=I18nObject(en_US="Another desc", zh_Hans="Another desc"), + icon='{"type": "emoji", "content": "⚙️"}', + icon_dark='{"type": "emoji", "content": "🔧"}', + label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"), + ) + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["automation"], + workflow_app_id="app_456", + ) + + assert result.id == "provider_456" + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_456" + assert result.is_team_authorization is True + assert result.allow_delete is True diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..e3c0749494 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8c007877fd..c3fe6a2950 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -510,7 +510,7 @@ class TestWorkflowConverter: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=10, score_threshold=0.8, - reranking_model={"provider": "cohere", "model": "rerank-v2"}, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, reranking_enabled=True, ), ) @@ -543,8 +543,8 @@ class TestWorkflowConverter: multiple_config = node["data"]["multiple_retrieval_config"] assert multiple_config["top_k"] == 10 assert multiple_config["score_threshold"] == 0.8 - assert multiple_config["reranking_model"]["provider"] == "cohere" - assert multiple_config["reranking_model"]["model"] == "rerank-v2" + assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere" + assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2" # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py new file mode 100644 index 0000000000..29e1e240b4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -0,0 +1,158 @@ +"""Testcontainers integration tests for WorkflowService.delete_workflow.""" + +import json +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + + +class TestWorkflowDeletion: + def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + session.add(tenant) + session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"wf_del_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + session.add(join) + session.flush() + return tenant, account + + def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App: + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + workflow_id=workflow_id, + ) + session.add(app) + session.flush() + return app + + def _create_workflow( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0" + ) -> Workflow: + workflow = Workflow( + id=str(uuid4()), + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + session.add(workflow) + session.flush() + return workflow + + def _create_tool_provider( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str + ) -> WorkflowToolProvider: + provider = WorkflowToolProvider( + name=f"tool-{uuid4()}", + label=f"Tool {uuid4()}", + icon="wrench", + app_id=app.id, + version=version, + user_id=account.id, + tenant_id=tenant.id, + description="test tool provider", + ) + session.add(provider) + session.flush() + return provider + + def test_delete_workflow_success(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + db_session_with_containers.commit() + workflow_id = workflow.id + + service = WorkflowService(sessionmaker(bind=db.engine)) + result = service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id + ) + + assert result is True + db_session_with_containers.expire_all() + assert db_session_with_containers.get(Workflow, workflow_id) is None + + def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="draft" + ) + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(DraftWorkflowDeletionError): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + # Point app to this workflow + app.workflow_id = workflow.id + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="currently in use by app"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0") + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="published as a tool"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6adefd59be..6cbbe43137 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,8 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -151,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -209,7 +211,7 @@ class TestBatchCleanDocumentTask: upload_file = UploadFile( tenant_id=account.current_tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -391,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -524,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index ebe5ff1d96..d2e343ef52 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,8 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -140,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, @@ -178,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -203,7 +205,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -220,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -263,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -450,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -466,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -482,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -654,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 638752cf8b..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,8 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -152,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -191,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), @@ -254,7 +256,7 @@ class TestCleanDatasetTask: upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -868,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -925,7 +927,7 @@ class TestCleanDatasetTask: special_filename = f"test_file_{special_content}.txt" upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{special_filename}", name=special_filename, size=1024, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69..926c839c8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c320..9f8e37fc9e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = ["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b..d457b59d58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb4749..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a54..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388c..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -63,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b105..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -244,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 182c9ef882..5bded4d670 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ import pytest from core.db.session_factory import session_factory from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole @@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: for i in range(count): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test/file-{uuid.uuid4()}-{i}.json", name=f"file-{i}.json", size=1024 + i, diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py new file mode 100644 index 0000000000..34a1941c39 --- /dev/null +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from extensions.storage.opendal_storage import OpenDALStorage + + +class TestOpenDALFsDefaultRoot: + """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" + + def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + """When no root is specified, the default 'storage' should be used and passed to the Operator.""" + # Change to tmp_path so the default "storage" dir is created there + monkeypatch.chdir(tmp_path) + # Ensure no OPENDAL_FS_ROOT env var is set + monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False) + + storage = OpenDALStorage(scheme="fs") + + # The default directory should have been created + assert (tmp_path / "storage").is_dir() + # The storage should be functional + storage.save("test_default_root.txt", b"hello") + assert storage.exists("test_default_root.txt") + assert storage.load_once("test_default_root.txt") == b"hello" + + # Cleanup + storage.delete("test_default_root.txt") + + def test_fs_with_explicit_root(self, tmp_path): + """When root is explicitly provided, it should be used.""" + custom_root = str(tmp_path / "custom_storage") + storage = OpenDALStorage(scheme="fs", root=custom_root) + + assert Path(custom_root).is_dir() + storage.save("test_explicit_root.txt", b"world") + assert storage.exists("test_explicit_root.txt") + assert storage.load_once("test_explicit_root.txt") == b"world" + + # Cleanup + storage.delete("test_explicit_root.txt") + + def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" + env_root = str(tmp_path / "env_storage") + monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) + # Ensure .env file doesn't interfere + monkeypatch.chdir(tmp_path) + + storage = OpenDALStorage(scheme="fs") + + assert Path(env_root).is_dir() + storage.save("test_env_root.txt", b"env_data") + assert storage.exists("test_env_root.txt") + assert storage.load_once("test_env_root.txt") == b"env_data" + + # Cleanup + storage.delete("test_env_root.txt") diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" @@ -281,12 +328,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +350,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index f100080eaa..0e22db9f9b 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -129,6 +129,136 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) handler(api, app_model=SimpleNamespace(id="app")) +def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="restored-hash", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + response = handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert response["result"] == "success" + assert response["hash"] == "restored-hash" + + +def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.WorkflowNotFoundError("Workflow not found") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(NotFound): + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + +def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.IsDraftWorkflowError( + "Cannot use draft workflow version. Workflow ID: draft-workflow. " + "Please use a published workflow version or leave workflow_id empty." + ) + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="draft-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE + + +def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + ValueError("invalid workflow graph") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == "invalid workflow graph" + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 3b8679f4ec..ebbb34e069 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -59,6 +59,44 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template + def test_get_returns_404_when_template_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + + def test_get_returns_404_for_customized_type_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=customized"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + class TestCustomizedPipelineTemplateApi: def test_patch_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 7775cbdd81..472d133349 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -2,7 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import Forbidden, HTTPException, NotFound import services from controllers.console import console_ns @@ -19,13 +19,14 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineDraftNodeRunApi, RagPipelineDraftRunIterationNodeApi, RagPipelineDraftRunLoopNodeApi, + RagPipelineDraftWorkflowRestoreApi, RagPipelineRecommendedPluginApi, RagPipelineTaskStopApi, RagPipelineTransformApi, RagPipelineWorkflowLastRunApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -116,6 +117,86 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 + def test_restore_published_workflow_to_draft_success(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) + + service = MagicMock() + service.restore_published_workflow_to_draft.return_value = workflow + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "published-workflow") + + assert result["result"] == "success" + assert result["hash"] == "restored-hash" + + def test_restore_published_workflow_to_draft_not_found(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found") + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "published-workflow") + + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError( + "source workflow must be published" + ) + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(HTTPException) as exc: + method(api, pipeline, "draft-workflow") + + assert exc.value.code == 400 + assert exc.value.description == "source workflow must be published" + class TestDraftRunNodes: def test_iteration_node_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..d841f67f9b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index f9fc2ac397..ff565f19fd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,8 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile from services.dataset_service import DatasetPermissionService, DatasetService @@ -1121,7 +1123,7 @@ class TestDatasetIndexingEstimateApi: def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key="key", name="name.txt", size=1, @@ -1145,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } @@ -1474,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1524,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1542,8 +1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1589,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1623,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1651,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1679,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f23dd5b44a..ce2278de4f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -139,8 +140,8 @@ class TestDatasetDocumentListApi: return_value=pagination, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=2, ), patch( "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", @@ -699,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=log, ), ): response, status = method(api, "ds-1", "doc-1") @@ -765,8 +764,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,19 +821,16 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) - query_mock = MagicMock() - query_mock.where.return_value.first.return_value = None - with ( app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -849,7 +845,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -862,10 +858,8 @@ class TestDocumentIndexingEstimateApi: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", @@ -973,7 +967,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +995,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1018,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1238,12 +1232,8 @@ class TestDocumentPermissionCases: return_value=None, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=lambda *a: MagicMock( - order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) - ) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=process_rule, ), ): result = method(api) @@ -1353,7 +1343,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -1363,8 +1353,8 @@ class TestDocumentIndexingEdgeCases: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad9..306a772fd1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -24,6 +24,7 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() @@ -525,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -620,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -705,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -737,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -769,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -830,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -879,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -923,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -969,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1179,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1214,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py index 90f00711c1..e358435de4 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -26,12 +26,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = None - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=None, ) with pytest.raises(PipelineNotFoundError): @@ -51,12 +48,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -76,12 +70,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -100,18 +91,15 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - def where_side_effect(*args, **kwargs): - assert args[0].right.value == "123" - return Mock(first=lambda: pipeline) - - mock_query = Mock() - mock_query.where.side_effect = where_side_effect - - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + mock_scalar = mocker.patch( + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id=123) assert result is pipeline + # Verify the pipeline_id was cast to string in the where clause + stmt = mock_scalar.call_args[0][0] + where_clauses = stmt.whereclause.clauses + assert where_clauses[0].right.value == "123" diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index 0606219356..c8f674f515 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -2,6 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import controllers.console.explore.banner as banner_module +from models.enums import BannerStatus def unwrap(func): @@ -20,16 +21,11 @@ class TestBannerApi: banner.content = {"text": "hello"} banner.link = "https://example.com" banner.sort = 1 - banner.status = "enabled" + banner.status = BannerStatus.ENABLED banner.created_at = datetime(2024, 1, 1) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [banner] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [banner] with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session): result = method(api) @@ -54,19 +50,17 @@ class TestBannerApi: banner.content = {"text": "fallback"} banner.link = None banner.sort = 1 - banner.status = "enabled" + banner.status = BannerStatus.ENABLED banner.created_at = None - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.side_effect = [ + scalars_result = MagicMock() + scalars_result.all.side_effect = [ [], [banner], ] session = MagicMock() - session.query.return_value = query + session.scalars.return_value = scalars_result with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session): result = method(api) @@ -86,13 +80,8 @@ class TestBannerApi: api = banner_module.BannerApi() method = unwrap(api.get) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [] with app.test_request_context("/"), patch.object(banner_module.db, "session", session): result = method(api) diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py index 3983a6a97e..93652e75d2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi: app_entity.tenant_id = "t2" session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - None, - ] + # scalar() is called for recommended_app and installed_app lookups + session.scalar.side_effect = [recommended, None] + # get() is called for app PK lookup + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi: method = unwrap(api.post) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi: app_entity = MagicMock(is_public=False) session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - ] + # scalar() returns recommended_app + session.scalar.return_value = recommended + # get() returns the app entity + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index d85114c8fb..5a03daecbc 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -958,8 +958,8 @@ class TestTrialSitApi: app_model = MagicMock() app_model.id = "a1" - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = None with pytest.raises(Forbidden): method(api, app_model) @@ -973,8 +973,8 @@ class TestTrialSitApi: app_model.tenant = MagicMock() app_model.tenant.status = TenantStatus.ARCHIVE - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = site + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = site with pytest.raises(Forbidden): method(api, app_model) @@ -990,10 +990,10 @@ class TestTrialSitApi: with ( app.test_request_context("/"), - patch.object(module.db.session, "query") as mock_query, + patch.object(module.db.session, "scalar") as mock_scalar, patch.object(module.SiteResponse, "model_validate") as mock_validate, ): - mock_query.return_value.where.return_value.first.return_value = site + mock_scalar.return_value = site mock_validate_result = MagicMock() mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"} mock_validate.return_value = mock_validate_result diff --git a/api/tests/unit_tests/controllers/console/explore/test_wraps.py b/api/tests/unit_tests/controllers/console/explore/test_wraps.py index 67e7a32591..2c1acfc3d6 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/explore/test_wraps.py @@ -34,9 +34,9 @@ def test_installed_app_required_not_found(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(NotFound): view("app-id") @@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, patch("controllers.console.explore.wraps.db.session.delete"), patch("controllers.console.explore.wraps.db.session.commit"), ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app with pytest.raises(NotFound): view("app-id") @@ -76,9 +76,9 @@ def test_installed_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app result = view("app-id") assert result == installed_app @@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(TrialAppNotAllowed): view("app-id") @@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] @@ -194,9 +194,9 @@ def test_trial_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 769edc8d1c..e89b89c8b1 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -11,6 +11,7 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models.enums import TagType def unwrap(func): @@ -52,7 +53,7 @@ def tag(): tag = MagicMock() tag.id = "tag-1" tag.name = "test-tag" - tag.type = "knowledge" + tag.type = TagType.KNOWLEDGE return tag diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index 018257f815..2dff9c4037 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" @@ -114,7 +115,7 @@ class TestBaseApiKeyResource: def test_delete_key_not_found(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = None + db_mock.session.scalar.return_value = None with patch("controllers.console.apikey._get_resource"): with pytest.raises(Exception) as exc_info: @@ -125,7 +126,7 @@ class TestBaseApiKeyResource: def test_delete_success(self, tenant_context_admin, db_mock): resource = DummyApiKeyResource() - db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock() + db_mock.session.scalar.return_value = MagicMock() with ( patch("controllers.console.apikey._get_resource"), diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 6777077de8..f6e096a97b 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -328,7 +328,7 @@ class TestSystemSetup: def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = "some_password" @setup_required @@ -345,7 +345,7 @@ class TestSystemSetup: def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = None # No INIT_PASSWORD @setup_required diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index 00d322fdea..42be02cdaf 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -55,9 +55,9 @@ class TestAccountInitApi: patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), patch("controllers.console.workspace.account.db.session.commit", return_value=None), patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), - patch("controllers.console.workspace.account.db.session.query") as query_mock, + patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock, ): - query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused") + scalar_mock.return_value = MagicMock(status="unused") resp = method(api) assert resp["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index b6708d1f6f..718b57ba6b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -207,10 +207,10 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 200 @@ -226,9 +226,9 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, ): - q.return_value.where.return_value.first.return_value = None + get_mock.return_value = None with pytest.raises(HTTPException): method(api, "x") @@ -244,13 +244,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.CannotOperateSelfError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 400 @@ -266,13 +266,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.NoPermissionError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 403 @@ -288,13 +288,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.MemberNotInTenantError(), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 404 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index 06f666fa60..f5ebe0b534 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -36,7 +36,115 @@ def unwrap(func): class TestTenantListApi: - def test_get_success(self, app): + def test_get_success_saas_path(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={ + "t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}, + "t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0}, + }, + ) as get_plan_bulk_mock, + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert len(result["workspaces"]) == 2 + assert result["workspaces"][0]["current"] is True + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_not_called() + + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. + + billing.enabled is mocked False to prove the endpoint does not gate on it for this path + (SaaS contract treats enabled as on; display follows subscription.plan). + """ + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + features_t2 = MagicMock() + features_t2.billing.enabled = False + features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}}, + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features_t2, + ) as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_called_once_with("t2") + + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + """Test fallback to FeatureService when bulk billing returns empty result. + + BillingService.get_plan_bulk catches exceptions internally and returns empty dict, + so we simulate the real failure mode by returning empty dict for non-empty input. + """ api = TenantListApi() method = unwrap(api.get) @@ -54,27 +162,41 @@ class TestTenantListApi: ) features = MagicMock() - features.billing.enabled = True - features.billing.subscription.plan = CloudPlan.SANDBOX + features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.TEAM with ( app.test_request_context("/workspaces"), patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], ), - patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={}, # Simulates real failure: empty result for non-empty input + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ) as get_features_mock, + patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock, ): result, status = method(api) assert status == 200 - assert len(result["workspaces"]) == 2 - assert result["workspaces"][0]["current"] is True + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.TEAM + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + assert get_features_mock.call_count == 2 + logger_warning_mock.assert_called_once() - def test_get_billing_disabled(self, app): + def test_get_billing_disabled_community_path(self, app): api = TenantListApi() method = unwrap(api.get) @@ -87,6 +209,7 @@ class TestTenantListApi: features = MagicMock() features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.SANDBOX with ( app.test_request_context("/workspaces"), @@ -98,15 +221,83 @@ class TestTenantListApi: "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant], ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), patch( "controllers.console.workspace.workspace.FeatureService.get_features", return_value=features, - ), + ) as get_features_mock, ): result, status = method(api) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + get_features_mock.assert_called_once_with("t1") + + def test_get_enterprise_only_skips_feature_service(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][0]["current"] is False + assert result["workspaces"][1]["current"] is True + get_features_mock.assert_not_called() + + def test_get_enterprise_only_with_empty_tenants(self, app): + api = TenantListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None) + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"] == [] + get_features_mock.assert_not_called() class TestWorkspaceListApi: @@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi: "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} ), ): - query_mock.return_value.get.return_value = tenant + get_mock.return_value = tenant result = method(api) assert result["result"] == "success" @@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi: return_value=(MagicMock(), "t1"), ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, ): - query_mock.return_value.get.return_value = None + get_mock.return_value = None with pytest.raises(ValueError): method(api) diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 6de07a23e5..eac57fe4b7 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -50,7 +50,7 @@ class TestGetUser: mock_user.id = "user123" mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + mock_session.get.return_value = mock_user # Act with app.app_context(): @@ -58,7 +58,7 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.query.assert_called_once() + mock_session.get.assert_called_once() @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.Session") @@ -72,7 +72,8 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + # non-anonymous path uses session.get(); anonymous uses session.scalar() + mock_session.get.return_value = mock_user # Act with app.app_context(): @@ -89,7 +90,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = None + mock_session.get.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user @@ -103,18 +104,20 @@ class TestGetUser: mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.Session") @patch("controllers.inner_api.plugin.wraps.db") def test_should_use_default_session_id_when_user_id_none( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask + self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask ): """Test using default session ID when user_id is None""" # Arrange mock_user = MagicMock() mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.where.return_value.first.return_value = mock_user + # When user_id is None, is_anonymous=True, so session.scalar() is used + mock_session.scalar.return_value = mock_user # Act with app.app_context(): @@ -133,7 +136,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.query.side_effect = Exception("Database error") + mock_session.get.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -161,9 +164,9 @@ class TestGetUserTenant: # Act with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}): monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_get_user.return_value = mock_user result = protected_view() @@ -194,8 +197,8 @@ class TestGetUserTenant: # Act & Assert with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}): - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: + mock_get.return_value = None with pytest.raises(ValueError, match="tenant not found"): protected_view() @@ -215,9 +218,9 @@ class TestGetUserTenant: # Act - use empty string for user_id to trigger default logic with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}): monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) - with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_get_user.return_value = mock_user result = protected_view() diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py index 883ccdea2c..efe1841f08 100644 --- a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -249,8 +249,8 @@ class TestEnterpriseInnerApiUserAuth: headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} ): with patch.object(dify_config, "INNER_API", True): - with patch("controllers.inner_api.wraps.db.session.query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = mock_user + with patch("controllers.inner_api.wraps.db.session.get") as mock_get: + mock_get.return_value = mock_user result = protected_view() # Assert diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 4fbf0f7125..56a8f94963 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -91,7 +91,7 @@ class TestEnterpriseWorkspace: # Arrange mock_account = MagicMock() mock_account.email = "owner@example.com" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account now = datetime(2025, 1, 1, 12, 0, 0) mock_tenant = MagicMock() @@ -122,7 +122,7 @@ class TestEnterpriseWorkspace: def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask): """Test that post() returns 404 when the owner account does not exist""" # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act unwrapped_post = inspect.unwrap(api_instance.post) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index 4de12de829..c2b8aed1ae 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -31,6 +31,7 @@ from controllers.service_api.app.message import ( MessageListQuery, MessageSuggestedApi, ) +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( @@ -310,7 +311,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id=str(uuid.uuid4()), user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content="Great response!", ) @@ -326,7 +327,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id="invalid_message_id", user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content=None, ) diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 4337a0c8c0..01d2d1e7c0 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -175,7 +176,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050c..8fe41cd19f 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import ( from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.tag_service import TagService @@ -277,7 +278,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Test Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag_service.get_tags.return_value = [mock_tag] @@ -316,7 +317,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "new_tag_1" mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag_service.save_tags.return_value = mock_tag mock_service_api_ns.payload = {"name": "New Tag"} @@ -378,7 +379,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Updated Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "5" mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.get_tag_binding_count.return_value = 5 @@ -866,7 +867,7 @@ class TestTagService: mock_tag = Mock() mock_tag.id = str(uuid.uuid4()) mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_save.return_value = mock_tag result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef1804..73a87761d5 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -788,7 +789,7 @@ class TestSegmentApiGet: # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -903,7 +904,7 @@ class TestSegmentApiPost: mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None @@ -1091,7 +1092,7 @@ class TestDatasetSegmentApiDelete: mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1371,7 +1372,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} @@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be19..7f77e61ee4 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 61fce3ed97..95c2f5cf92 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -39,14 +39,21 @@ class TestHitTestingPayload: def test_payload_with_all_fields(self): """Test payload with all optional fields.""" + retrieval_model_data = { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": False, + "top_k": 5, + } payload = HitTestingPayload( query="test query", - retrieval_model={"top_k": 5}, + retrieval_model=retrieval_model_data, external_retrieval_model={"provider": "openai"}, attachment_ids=["att_1", "att_2"], ) assert payload.query == "test query" - assert payload.retrieval_model == {"top_k": 5} + assert payload.retrieval_model is not None + assert payload.retrieval_model.top_k == 5 assert payload.external_retrieval_model == {"provider": "openai"} assert payload.attachment_ids == ["att_1", "att_2"] @@ -134,7 +141,13 @@ class TestHitTestingApiPost: mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8} + retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": True, + "top_k": 10, + "score_threshold": 0.8, + } mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []} mock_hit_svc.hit_testing_args_check.return_value = None @@ -152,7 +165,11 @@ class TestHitTestingApiPost: assert response["query"] == "complex query" call_kwargs = mock_hit_svc.retrieve.call_args - assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model + # retrieval_model is serialized via model_dump, verify key fields + passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model") + assert passed_retrieval_model is not None + assert passed_retrieval_model["search_method"] == "semantic_search" + assert passed_retrieval_model["top_k"] == 10 @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") diff --git a/api/tests/unit_tests/controllers/trigger/test_webhook.py b/api/tests/unit_tests/controllers/trigger/test_webhook.py index d633365f2b..91c793d292 100644 --- a/api/tests/unit_tests/controllers/trigger/test_webhook.py +++ b/api/tests/unit_tests/controllers/trigger/test_webhook.py @@ -23,6 +23,7 @@ def mock_jsonify(): class DummyWebhookTrigger: webhook_id = "wh-1" + webhook_url = "http://localhost:5001/triggers/webhook/wh-1" tenant_id = "tenant-1" app_id = "app-1" node_id = "node-1" @@ -104,7 +105,32 @@ class TestHandleWebhookDebug: @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") @patch.object(module.WebhookService, "extract_and_validate_webhook_data") @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) - @patch.object(module.TriggerDebugEventBus, "dispatch") + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0) + def test_debug_requires_active_listener( + self, + mock_dispatch, + mock_build_inputs, + mock_extract, + mock_get, + ): + mock_get.return_value = (DummyWebhookTrigger(), None, "node_config") + mock_extract.return_value = {"method": "POST"} + + response, status = module.handle_webhook_debug("wh-1") + + assert status == 409 + assert response["error"] == "No active debug listener" + assert response["message"] == ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ) + assert response["execution_url"] == DummyWebhookTrigger.webhook_url + mock_dispatch.assert_called_once() + + @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") + @patch.object(module.WebhookService, "extract_and_validate_webhook_data") + @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1) @patch.object(module.WebhookService, "generate_webhook_response") def test_debug_success( self, diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index 4fb735b033..a1dbc80b20 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -49,6 +49,17 @@ class _FakeSession: assert self._model_name is not None return self._mapping.get(self._model_name) + def get(self, model, ident): + return self._mapping.get(model.__name__) + + def scalar(self, stmt): + # Extract the model name from the select statement's column_descriptions + try: + name = stmt.column_descriptions[0]["entity"].__name__ + except (AttributeError, IndexError, KeyError): + return None + return self._mapping.get(name) + class _FakeDB: """Minimal db stub exposing engine and session.""" diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/unit_tests/controllers/web/test_site.py index 557bf93e9e..6e9d754c43 100644 --- a/api/tests/unit_tests/controllers/web/test_site.py +++ b/api/tests/unit_tests/controllers/web/test_site.py @@ -50,7 +50,7 @@ class TestAppSiteApi: app.config["RESTX_MASK_HEADER"] = "X-Fields" mock_features.return_value = SimpleNamespace(can_replace_logo=False) site_obj = _site() - mock_db.session.query.return_value.where.return_value.first.return_value = site_obj + mock_db.session.scalar.return_value = site_obj tenant = _tenant() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) end_user = SimpleNamespace(id="eu-1") @@ -66,9 +66,9 @@ class TestAppSiteApi: @patch("controllers.web.site.db") def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) end_user = SimpleNamespace(id="eu-1") with app.test_request_context("/site"): @@ -80,7 +80,7 @@ class TestAppSiteApi: app.config["RESTX_MASK_HEADER"] = "X-Fields" from models.account import TenantStatus - mock_db.session.query.return_value.where.return_value.first.return_value = _site() + mock_db.session.scalar.return_value = _site() tenant = SimpleNamespace( id="tenant-1", status=TenantStatus.ARCHIVE, diff --git a/api/tests/unit_tests/core/app/app_config/__init__.py b/api/tests/unit_tests/core/app/app_config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py new file mode 100644 index 0000000000..1c5b6ed944 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py @@ -0,0 +1,227 @@ +from unittest.mock import MagicMock + +import pytest + +# Module under test +from core.app.app_config.common import parameters_mapping + + +class TestGetParametersFromFeatureDict: + """Test suite for get_parameters_from_feature_dict""" + + @pytest.fixture + def mock_config(self, monkeypatch): + """Mock dify_config values""" + mock = MagicMock() + mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1 + mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2 + mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3 + mock.UPLOAD_FILE_SIZE_LIMIT = 4 + mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5 + + monkeypatch.setattr(parameters_mapping, "dify_config", mock) + return mock + + @pytest.fixture + def mock_default_file_limits(self, monkeypatch): + """Mock DEFAULT_FILE_NUMBER_LIMITS constant""" + monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99) + return 99 + + @pytest.fixture + def minimal_inputs(self): + return {}, [] + + @pytest.mark.parametrize( + ("feature_key", "expected_default"), + [ + ("suggested_questions", []), + ("suggested_questions_after_answer", {"enabled": False}), + ("speech_to_text", {"enabled": False}), + ("text_to_speech", {"enabled": False}), + ("retriever_resource", {"enabled": False}), + ("annotation_reply", {"enabled": False}), + ("more_like_this", {"enabled": False}), + ( + "sensitive_word_avoidance", + {"enabled": False, "type": "", "configs": []}, + ), + ], + ) + def test_defaults_when_key_missing( + self, + feature_key, + expected_default, + mock_config, + mock_default_file_limits, + ): + # Arrange + features = {} + user_input = [] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + assert result[feature_key] == expected_default + + def test_opening_statement_present(self, mock_config, mock_default_file_limits): + # Arrange + features = {"opening_statement": "Hello"} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] == "Hello" + + def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] is None + + def test_all_features_provided(self, mock_config, mock_default_file_limits): + # Arrange + features = { + "opening_statement": "Hi", + "suggested_questions": ["Q1"], + "suggested_questions_after_answer": {"enabled": True}, + "speech_to_text": {"enabled": True}, + "text_to_speech": {"enabled": True}, + "retriever_resource": {"enabled": True}, + "annotation_reply": {"enabled": True}, + "more_like_this": {"enabled": True}, + "sensitive_word_avoidance": { + "enabled": True, + "type": "strict", + "configs": ["a"], + }, + "file_upload": { + "image": { + "enabled": True, + "number_limits": 10, + "detail": "low", + "transfer_methods": ["local_file"], + } + }, + } + user_input = [{"name": "field1"}] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + for key in features: + assert result[key] == features[key] + assert result["user_input_form"] == user_input + + def test_file_upload_default_structure(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + file_upload = result["file_upload"] + assert file_upload["image"]["enabled"] is False + assert file_upload["image"]["number_limits"] == 99 + assert file_upload["image"]["detail"] == "high" + assert "remote_url" in file_upload["image"]["transfer_methods"] + assert "local_file" in file_upload["image"]["transfer_methods"] + + def test_system_parameters_from_config(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + system_params = result["system_parameters"] + assert system_params["image_file_size_limit"] == 1 + assert system_params["video_file_size_limit"] == 2 + assert system_params["audio_file_size_limit"] == 3 + assert system_params["file_size_limit"] == 4 + assert system_params["workflow_file_upload_limit"] == 5 + + @pytest.mark.parametrize( + ("features_dict", "user_input_form"), + [ + (None, []), + ([], []), + ("invalid", []), + ], + ) + def test_invalid_features_dict_type_raises(self, features_dict, user_input_form): + # Act & Assert + with pytest.raises(AttributeError): + parameters_mapping.get_parameters_from_feature_dict( + features_dict=features_dict, + user_input_form=user_input_form, + ) + + @pytest.mark.parametrize( + "user_input_form", + [None, "invalid", 123], + ) + def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input_form, + ) + + # Assert + assert result["user_input_form"] == user_input_form + + def test_empty_user_input_form(self, mock_config, mock_default_file_limits): + features = {} + user_input = [] + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + assert result["user_input_form"] == [] + + def test_feature_values_none(self, mock_config, mock_default_file_limits): + features = { + "suggested_questions": None, + "speech_to_text": None, + } + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + assert result["suggested_questions"] is None + assert result["speech_to_text"] is None diff --git a/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py new file mode 100644 index 0000000000..013ed0cbc4 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.common.sensitive_word_avoidance.manager import ( + SensitiveWordAvoidanceConfigManager, +) + + +class TestSensitiveWordAvoidanceConfigManagerConvert: + """Tests for convert classmethod""" + + @pytest.mark.parametrize( + "config", + [ + {}, + {"sensitive_word_avoidance": None}, + {"sensitive_word_avoidance": {}}, + {"sensitive_word_avoidance": {"enabled": False}}, + ], + ) + def test_convert_returns_none_when_disabled_or_missing(self, config): + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result is None + + def test_convert_returns_entity_when_enabled(self, mocker): + # Arrange + mock_entity = MagicMock() + mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"key": "value"}, + } + } + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result == mock_entity + + def test_convert_enabled_without_type_or_config(self, mocker): + # Arrange + mock_entity = MagicMock() + patched = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = {"sensitive_word_avoidance": {"enabled": True}} + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + patched.assert_called_once_with(type=None, config={}) + assert result == mock_entity + + +class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults: + """Tests for validate_and_set_defaults classmethod""" + + @pytest.fixture + def base_config(self): + return {} + + def test_validate_sets_default_when_missing(self, base_config): + # Act + config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=base_config.copy() + ) + + # Assert + assert config["sensitive_word_avoidance"]["enabled"] is False + assert fields == ["sensitive_word_avoidance"] + + def test_validate_raises_when_not_dict(self): + config = {"sensitive_word_avoidance": "invalid"} + + with pytest.raises(ValueError, match="must be of dict type"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + @pytest.mark.parametrize( + "config", + [ + {"sensitive_word_avoidance": {"enabled": False}}, + {"sensitive_word_avoidance": {"enabled": None}}, + {"sensitive_word_avoidance": {}}, + ], + ) + def test_validate_disables_when_enabled_false_or_missing(self, config): + # Act + result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + assert result_config["sensitive_word_avoidance"]["enabled"] is False + + def test_validate_raises_when_enabled_true_without_type(self): + config = {"sensitive_word_avoidance": {"enabled": True}} + + with pytest.raises(ValueError, match="type is required"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_type_not_string(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": 123, + } + } + + with pytest.raises(ValueError, match="must be a string"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_config_not_dict(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": "invalid", + } + } + + with pytest.raises(ValueError, match="must be a dict"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_calls_moderation_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"}) + assert result_config["sensitive_word_avoidance"]["enabled"] is True + assert fields == ["sensitive_word_avoidance"] + + def test_validate_sets_empty_dict_when_config_none(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": None, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={}) + + def test_validate_only_structure_validate_skips_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config, only_structure_validate=True + ) + + # Assert + mock_validate.assert_not_called() diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py new file mode 100644 index 0000000000..992b580376 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager + + +class TestAgentConfigManagerConvert: + @pytest.fixture + def base_config(self): + return { + "agent_mode": { + "enabled": True, + "strategy": "cot", + "tools": [], + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "completion", + }, + } + + def test_convert_returns_none_when_agent_mode_missing(self): + config = {"model": {"provider": "openai", "name": "gpt-4"}} + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}]) + def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config): + config = base_config.copy() + config["agent_mode"] = agent_mode_value + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize( + ("strategy_input", "expected_enum"), + [ + ("function_call", "FUNCTION_CALLING"), + ("cot", "CHAIN_OF_THOUGHT"), + ("react", "CHAIN_OF_THOUGHT"), + ], + ) + def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": strategy_input, + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.strategy.name == expected_enum + + def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "openai" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "FUNCTION_CALLING" + + def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "anthropic" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "CHAIN_OF_THOUGHT" + + def test_convert_skips_disabled_tools(self, mocker, base_config): + # Patch AgentEntity to bypass pydantic validation + mock_agent_entity = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity", + return_value=MagicMock(), + ) + + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value={ + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "tool_parameters": {}, + "credential_id": None, + }, + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + { + "provider_type": "type1", + "provider_id": "id1", + "tool_name": "tool1", + "enabled": False, + }, + { + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "enabled": True, + "extra_key": "x", + }, + ], + } + + AgentConfigManager.convert(config) + + mock_validate.assert_called_once() + mock_agent_entity.assert_called_once() + + def test_convert_tool_requires_minimum_keys(self, mocker, base_config): + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value=MagicMock(), + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + {"a": 1, "b": 2}, # insufficient keys + ], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.tools == [] + mock_validate.assert_not_called() + + def test_convert_completion_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "completion" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_chat_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "chat" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_react_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "react_router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_max_iteration_default(self, base_config): + config = base_config.copy() + config["agent_mode"].pop("max_iteration", None) + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 10 + + def test_convert_custom_max_iteration(self, base_config): + config = base_config.copy() + config["agent_mode"]["max_iteration"] = 25 + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 25 + + def test_convert_missing_model_raises_key_error(self, base_config): + config = base_config.copy() + del config["model"] + + with pytest.raises(KeyError): + AgentConfigManager.convert(config) + + @pytest.mark.parametrize( + ("invalid_config", "should_raise"), + [ + (None, True), + (123, True), + ("", False), + ([], False), + ], + ) + def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise): + if should_raise: + with pytest.raises(TypeError): + AgentConfigManager.convert(invalid_config) # type: ignore + else: + result = AgentConfigManager.convert(invalid_config) # type: ignore + assert result is None diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py new file mode 100644 index 0000000000..a688e2a5c5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -0,0 +1,319 @@ +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def valid_uuid(): + return str(uuid.uuid4()) + + +@pytest.fixture +def base_config(valid_uuid): + return { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + + +@pytest.fixture +def mock_dataset_service(mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + +# ============================== +# convert tests +# ============================== + + +class TestDatasetConfigManagerConvert: + def test_convert_returns_none_when_no_datasets(self): + config = {"dataset_configs": {"datasets": {"datasets": []}}} + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_single_retrieval(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result is not None + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable == "query" + + def test_convert_single_with_metadata_configs(self, valid_uuid, mocker): + mock_retrieve_config = MagicMock() + mock_entity = MagicMock() + mock_entity.dataset_ids = [valid_uuid] + mock_entity.retrieve_config = mock_retrieve_config + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig", + return_value={"mock": "model"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition", + return_value={"mock": "condition"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity", + return_value=mock_retrieve_config, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity", + return_value=mock_entity, + ) + + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "metadata_filtering_mode": "manual", + "metadata_model_config": {"any": "value"}, + "metadata_filtering_conditions": {"any": "value"}, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config is mock_retrieve_config + + def test_convert_multiple_defaults(self, valid_uuid): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 4 + assert result.retrieve_config.score_threshold is None + assert result.retrieve_config.reranking_enabled is True + + def test_convert_agent_mode_disabled_tool(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}], + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_dataset_configs_none(self): + config = {"dataset_configs": None} + with pytest.raises(TypeError): + DatasetConfigManager.convert(config) + + def test_convert_agent_mode_old_style_old_format(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable is None + + def test_convert_multiple_with_score_threshold(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "multiple", + "top_k": 5, + "score_threshold": 0.8, + "score_threshold_enabled": True, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 5 + assert result.retrieve_config.score_threshold == 0.8 + + @pytest.mark.parametrize( + "dataset_entry", + [ + {}, + {"invalid": {}}, + {"dataset": {"id": None, "enabled": True}}, + {"dataset": {"id": "", "enabled": False}}, + ], + ) + def test_convert_ignores_invalid_dataset_entries(self, dataset_entry): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": {"strategy": "router", "datasets": [dataset_entry]}, + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_agent_mode_old_style(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + + +# ============================== +# validate_and_set_defaults tests +# ============================== + + +class TestValidateAndSetDefaults: + def test_validate_sets_defaults(self): + config = {} + updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + assert "dataset_configs" in updated + assert updated["dataset_configs"]["retrieval_model"] == "single" + assert isinstance(fields, list) + + def test_validate_raises_when_dataset_configs_not_dict(self): + config = {"dataset_configs": "invalid"} + with pytest.raises(AttributeError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + + def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid): + config = { + "dataset_configs": { + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + } + with pytest.raises(ValueError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfig: + def test_extract_sets_defaults(self): + config = {} + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + assert "agent_mode" in result + assert result["agent_mode"]["enabled"] is False + assert result["agent_mode"]["tools"] == [] + + def test_extract_invalid_agent_mode_type(self): + config = {"agent_mode": "invalid"} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_enabled_type(self): + config = {"agent_mode": {"enabled": "yes"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_tools_type(self): + config = {"agent_mode": {"enabled": True, "tools": "invalid"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_uuid(self, mocker): + invalid_uuid = "not-a-uuid" + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_dataset_not_exists(self, valid_uuid, mocker): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + +# ============================== +# is_dataset_exists tests +# ============================== + + +class TestIsDatasetExists: + def test_dataset_exists_true(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "other" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py new file mode 100644 index 0000000000..aed1651511 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.entities.model_entities import ModelStatus +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey + + +class TestModelConfigConverter: + @pytest.fixture(autouse=True) + def patch_response_entity(self, mocker): + """ + Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation + and return a simple namespace object instead. + """ + + def _factory(**kwargs): + return SimpleNamespace(**kwargs) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity", + side_effect=_factory, + ) + + @pytest.fixture + def mock_app_config(self): + app_config = MagicMock() + app_config.tenant_id = "tenant_1" + + model_config = MagicMock() + model_config.provider = "openai" + model_config.model = "gpt-4" + model_config.parameters = {"temperature": 0.5} + model_config.mode = None + + app_config.model = model_config + return app_config + + @pytest.fixture + def mock_provider_bundle(self): + bundle = MagicMock() + + # configuration + configuration = MagicMock() + configuration.provider.provider = "openai" + configuration.get_current_credentials.return_value = {"api_key": "key"} + + provider_model = MagicMock() + provider_model.status = ModelStatus.ACTIVE + configuration.get_provider_model.return_value = provider_model + + bundle.configuration = configuration + + # model type instance + model_type_instance = MagicMock() + model_schema = MagicMock() + model_schema.model_properties = {} + model_type_instance.get_model_schema.return_value = model_schema + bundle.model_type_instance = model_type_instance + + return bundle + + @pytest.fixture + def patch_provider_manager(self, mocker, mock_provider_bundle): + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + return mock_manager + + # ============================= + # Positive Scenarios + # ============================= + + def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager): + result = ModelConfigConverter.convert(mock_app_config) + + assert result.provider == "openai" + assert result.model == "gpt-4" + assert result.mode == LLMMode.CHAT + assert result.parameters == {"temperature": 0.5} + assert result.stop == [] + + def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager): + mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]} + + result = ModelConfigConverter.convert(mock_app_config) + + assert result.parameters == {"temperature": 0.7} + assert result.stop == ["\n"] + + def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker): + mock_app_config.model.mode = None + + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: LLMMode.COMPLETION.value + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.COMPLETION + + def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: "invalid" + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.CHAT + + # ============================= + # Credential Errors + # ============================= + + def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_current_credentials.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ProviderTokenNotInitError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Provider Model Errors + # ============================= + + def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_provider_model.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError), + (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError), + (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError), + ], + ) + def test_convert_provider_model_status_errors( + self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception + ): + mock_provider = MagicMock() + mock_provider.status = status + mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(expected_exception): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Schema Errors + # ============================= + + def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Edge Cases + # ============================= + + @pytest.mark.parametrize( + "parameters", + [ + {}, + {"stop": []}, + {"stop": ["END"], "max_tokens": 100}, + ], + ) + def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters): + mock_app_config.model.parameters = parameters.copy() + + result = ModelConfigConverter.convert(mock_app_config) + + if "stop" in parameters: + assert result.stop == parameters.get("stop") + expected_params = parameters.copy() + expected_params.pop("stop", None) + assert result.parameters == expected_params + else: + assert result.stop == [] + assert result.parameters == parameters diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py new file mode 100644 index 0000000000..e2ba276d8e --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -0,0 +1,230 @@ +from unittest.mock import MagicMock + +import pytest + +# Target +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def valid_completion_params(): + return {"temperature": 0.7, "stop": ["\n"]} + + +@pytest.fixture +def valid_model_list(): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {"mode": "chat"} + return [model] + + +@pytest.fixture +def provider_entities(): + provider = MagicMock() + provider.provider = "openai/gpt" + return [provider] + + +@pytest.fixture +def valid_config(): + return { + "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}} + } + + +# ----------------------------- +# Test Class +# ----------------------------- + + +class TestModelConfigManager: + # ========================================================== + # convert + # ========================================================== + + def test_convert_success(self, valid_config): + result = ModelConfigManager.convert(valid_config) + + assert result.provider == "openai/gpt" + assert result.model == "gpt-4" + assert result.parameters == {"temperature": 0.5} + assert result.stop == ["END"] + + def test_convert_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.convert({}) + + def test_convert_without_stop(self): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}} + result = ModelConfigManager.convert(config) + assert result.stop == [] + assert result.parameters == {"temperature": 0.9} + + # ========================================================== + # validate_model_completion_params + # ========================================================== + + @pytest.mark.parametrize( + "invalid_cp", + [None, "string", 123, []], + ) + def test_validate_model_completion_params_invalid_type(self, invalid_cp): + with pytest.raises(ValueError, match="must be of object type"): + ModelConfigManager.validate_model_completion_params(invalid_cp) + + def test_validate_model_completion_params_default_stop(self): + cp = {"temperature": 0.2} + result = ModelConfigManager.validate_model_completion_params(cp) + assert result["stop"] == [] + + def test_validate_model_completion_params_invalid_stop_type(self): + cp = {"stop": "invalid"} + with pytest.raises(ValueError, match="must be of list type"): + ModelConfigManager.validate_model_completion_params(cp) + + def test_validate_model_completion_params_stop_length_exceeded(self): + cp = {"stop": [1, 2, 3, 4, 5]} + with pytest.raises(ValueError, match="less than 4"): + ModelConfigManager.validate_model_completion_params(cp) + + # ========================================================== + # validate_and_set_defaults + # ========================================================== + + def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) + + assert updated_config["model"]["mode"] == "chat" + assert keys == ["model"] + + def test_validate_and_set_defaults_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", {}) + + def test_validate_and_set_defaults_model_not_dict(self): + with pytest.raises(ValueError, match="object type"): + ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"}) + + def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): + config = {"model": {"name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): + config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + with pytest.raises(ValueError, match="model.name is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = [] + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {} + + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model] + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + assert updated_config["model"]["mode"] == "completion" + + def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} + + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + mock_factory.return_value.get_providers.return_value = provider_entities + + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + with pytest.raises(ValueError, match="completion_params is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list): + """ + Covers branch where provider does not contain '/' and + ModelProviderID conversion is triggered (line 64). + """ + config = { + "model": { + "provider": "openai", # no slash -> triggers conversion + "name": "gpt-4", + "completion_params": {}, + } + } + + # Mock ModelProviderID to return formatted provider + mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") + mock_provider_id.return_value = "openai/gpt" + + # Mock provider factory + mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") + provider_entity = MagicMock() + provider_entity.provider = "openai/gpt" + mock_factory.return_value.get_providers.return_value = [provider_entity] + + # Mock provider manager + mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") + mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + # Ensure conversion happened + mock_provider_id.assert_called_once_with("openai") + assert updated_config["model"]["provider"] == "openai/gpt" diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py new file mode 100644 index 0000000000..fd49072cd5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( + PromptTemplateConfigManager, +) + +# ----------------------------- +# Helpers +# ----------------------------- + + +class DummyEnumValue: + def __init__(self, value): + self.value = value + + +class DummyPromptType: + def __init__(self): + self.SIMPLE = "simple" + self.ADVANCED = "advanced" + + def value_of(self, value): + return value + + def __iter__(self): + return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + + +# ----------------------------- +# Convert Tests +# ----------------------------- + + +class TestPromptTemplateConfigManagerConvert: + def test_convert_missing_prompt_type_raises(self): + with pytest.raises(ValueError, match="prompt_type is required"): + PromptTemplateConfigManager.convert({}) + + def test_convert_simple_prompt(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mock_prompt_entity_cls.return_value = "simple_entity" + + config = {"prompt_type": "simple", "pre_prompt": "hello"} + + result = PromptTemplateConfigManager.convert(config) + + assert result == "simple_entity" + mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello") + + def test_convert_advanced_chat_valid(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of", + return_value="role_enum", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity", + return_value="chat_msg", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity", + return_value="chat_template", + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]}, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + @pytest.mark.parametrize( + "message", + [ + {"text": 123, "role": "user"}, + {"text": "hi", "role": 123}, + ], + ) + def test_convert_advanced_invalid_message_fields(self, mocker, message): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [message]}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.convert(config) + + def test_convert_advanced_completion_with_roles(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity", + return_value="completion_template", + ) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "complete"}, + "conversation_histories_role": { + "user_prefix": "U", + "assistant_prefix": "A", + }, + }, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + +# ----------------------------- +# validate_and_set_defaults +# ----------------------------- + + +class TestValidateAndSetDefaults: + def setup_method(self): + self.valid_model = {"mode": "chat"} + + def _patch_prompt_type(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + return mock_prompt_entity_cls + + def test_default_prompt_type_set(self, mocker): + self._patch_prompt_type(mocker) + + config = {"model": self.valid_model} + + result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + assert result["prompt_type"] == "simple" + assert isinstance(keys, list) + + def test_invalid_prompt_type_raises(self, mocker): + class InvalidEnum(DummyPromptType): + def __iter__(self): + return iter([DummyEnumValue("valid")]) + + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = InvalidEnum() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = {"prompt_type": "invalid", "model": self.valid_model} + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_invalid_chat_prompt_config_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "chat_prompt_config": "invalid", + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_simple_mode_invalid_pre_prompt_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "pre_prompt": 123, + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_requires_one_config(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {}, + "completion_prompt_config": {}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_invalid_model_mode(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": []}, + "model": {"mode": "invalid"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_chat_prompt_length_exceeds(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{}] * 11}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_completion_prefix_defaults_set_when_empty(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "hi"}, + "conversation_histories_role": { + "user_prefix": "", + "assistant_prefix": "", + }, + }, + "model": {"mode": "completion"}, + } + + updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config) + + roles = updated["completion_prompt_config"]["conversation_histories_role"] + assert roles["user_prefix"] == "Human" + assert roles["assistant_prefix"] == "Assistant" + + +# ----------------------------- +# validate_post_prompt +# ----------------------------- + + +class TestValidatePostPrompt: + @pytest.mark.parametrize("value", [None, ""]) + def test_post_prompt_defaults(self, value): + config = {"post_prompt": value} + result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) + assert result["post_prompt"] == "" + + def test_post_prompt_invalid_type(self): + config = {"post_prompt": 123} + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py new file mode 100644 index 0000000000..5def29b741 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -0,0 +1,286 @@ +import pytest + +from core.app.app_config.easy_ui_based_app.variables.manager import ( + BasicVariablesConfigManager, +) +from dify_graph.variables.input_entities import VariableEntityType + + +class TestBasicVariablesConfigManagerConvert: + def test_convert_empty_config(self): + config = {} + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + def test_convert_external_data_tools_enabled_and_disabled(self, mocker): + config = { + "external_data_tools": [ + {"enabled": False}, + { + "enabled": True, + "variable": "ext_var", + "type": "tool_type", + "config": {"k": "v"}, + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert len(external) == 1 + assert external[0].variable == "ext_var" + assert external[0].type == "tool_type" + + def test_convert_user_input_form_variable_types(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "variable": "name", + "label": "Name", + "description": "desc", + "required": True, + "max_length": 50, + } + }, + { + VariableEntityType.SELECT: { + "variable": "choice", + "label": "Choice", + "options": ["a", "b"], + } + }, + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + "config": {"x": 1}, + } + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert len(variables) == 2 + assert len(external) == 1 + + def test_convert_external_data_tool_without_config_skipped(self): + config = { + "user_input_form": [ + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + } + } + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + +class TestValidateVariablesAndSetDefaults: + def test_validate_sets_empty_user_input_form_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"] == [] + assert "user_input_form" in keys + + def test_validate_user_input_form_not_list_raises(self): + config = {"user_input_form": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_invalid_key_raises(self): + config = {"user_input_form": [{"invalid": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_label_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_label_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_variable_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_variable_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + @pytest.mark.parametrize( + "variable_name", + ["1invalid", "invalid space", "", None], + ) + def test_validate_variable_invalid_pattern_raises(self, variable_name): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": variable_name, + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_required_default_and_type(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + } + } + ] + } + + updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False + + def test_validate_required_not_bool_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + "required": "yes", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_default_not_in_options_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": ["a", "b"], + "default": "c", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_not_list_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": "not_list", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + +class TestValidateExternalDataToolsAndSetDefaults: + def test_validate_sets_empty_external_data_tools_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + assert updated["external_data_tools"] == [] + assert "external_data_tools" in keys + + def test_validate_external_data_tools_not_list_raises(self): + config = {"external_data_tools": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_disabled_tool_skipped(self, mocker): + config = {"external_data_tools": [{"enabled": False}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + spy.assert_not_called() + assert updated["external_data_tools"][0]["enabled"] is False + + def test_validate_enabled_tool_missing_type_raises(self): + config = {"external_data_tools": [{"enabled": True, "config": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_enabled_tool_calls_factory(self, mocker): + config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config) + + spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1}) + + +class TestValidateAndSetDefaultsIntegration: + def test_validate_and_set_defaults_calls_both(self, mocker): + config = {} + + spy_var = mocker.patch.object( + BasicVariablesConfigManager, + "validate_variables_and_set_defaults", + return_value=(config, ["user_input_form"]), + ) + spy_ext = mocker.patch.object( + BasicVariablesConfigManager, + "validate_external_data_tools_and_set_defaults", + return_value=(config, ["external_data_tools"]), + ) + + updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config) + + spy_var.assert_called_once() + spy_ext.assert_called_once() + assert "user_input_form" in keys + assert "external_data_tools" in keys + assert updated == config diff --git a/api/tests/unit_tests/core/app/app_config/features/__init__.py b/api/tests/unit_tests/core/app/app_config/features/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py new file mode 100644 index 0000000000..dd00c3defc --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -0,0 +1,115 @@ +import pytest + +from core.app.app_config.entities import TextToSpeechEntity +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager + + +class TestAdditionalFeatureManagers: + def test_opening_statement_validate_defaults(self): + config, keys = OpeningStatementConfigManager.validate_and_set_defaults({}) + assert config["opening_statement"] == "" + assert config["suggested_questions"] == [] + assert set(keys) == {"opening_statement", "suggested_questions"} + + def test_opening_statement_validate_types(self): + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123}) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": "bad"} + ) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": [1]} + ) + + def test_opening_statement_convert(self): + opening, questions = OpeningStatementConfigManager.convert( + {"opening_statement": "hello", "suggested_questions": ["q1"]} + ) + assert opening == "hello" + assert questions == ["q1"] + + def test_retrieval_resource_validate(self): + config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({}) + assert config["retriever_resource"]["enabled"] is False + assert keys == ["retriever_resource"] + + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"}) + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}}) + + def test_retrieval_resource_convert(self): + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False + + def test_speech_to_text_validate_and_convert(self): + config, keys = SpeechToTextConfigManager.validate_and_set_defaults({}) + assert config["speech_to_text"]["enabled"] is False + assert keys == ["speech_to_text"] + + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"}) + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}}) + + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False + + def test_suggested_questions_after_answer_validate_and_convert(self): + config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({}) + assert config["suggested_questions_after_answer"]["enabled"] is False + assert keys == ["suggested_questions_after_answer"] + + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": "bad"} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": "yes"}} + ) + + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) + is True + ) + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}}) + is False + ) + + def test_text_to_speech_validate_and_convert(self): + config, keys = TextToSpeechConfigManager.validate_and_set_defaults({}) + assert config["text_to_speech"]["enabled"] is False + assert keys == ["text_to_speech"] + + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"}) + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}}) + + result = TextToSpeechConfigManager.convert( + {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}} + ) + assert isinstance(result, TextToSpeechEntity) + assert result.voice == "v" + assert result.language == "en" + + def test_more_like_this_convert_and_validate(self): + config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({}) + assert config["more_like_this"]["enabled"] is False + assert keys == ["more_like_this"] + + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False + with pytest.raises(ValueError): + MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"}) diff --git a/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py new file mode 100644 index 0000000000..e99852cf76 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py @@ -0,0 +1,180 @@ +from collections import UserDict +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager + + +class TestBaseAppConfigManager: + @pytest.fixture + def mock_config_dict(self): + return {"key": "value", "another": 123} + + @pytest.fixture + def mock_app_additional_features(self, mocker): + mock_instance = MagicMock() + mocker.patch( + "core.app.app_config.base_app_config_manager.AppAdditionalFeatures", + return_value=mock_instance, + ) + return mock_instance + + @pytest.fixture + def mock_managers(self, mocker): + retrieval = mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + return_value="retrieval_result", + ) + file_upload = mocker.patch( + "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert", + return_value="file_upload_result", + ) + opening_statement = mocker.patch( + "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert", + return_value=("opening_result", "suggested_result"), + ) + suggested_after = mocker.patch( + "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert", + return_value="suggested_after_result", + ) + more_like_this = mocker.patch( + "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert", + return_value="more_like_this_result", + ) + speech_to_text = mocker.patch( + "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert", + return_value="speech_to_text_result", + ) + text_to_speech = mocker.patch( + "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert", + return_value="text_to_speech_result", + ) + + return { + "retrieval": retrieval, + "file_upload": file_upload, + "opening_statement": opening_statement, + "suggested_after": suggested_after, + "more_like_this": more_like_this, + "speech_to_text": speech_to_text, + "text_to_speech": text_to_speech, + } + + @pytest.mark.parametrize( + ("app_mode", "expected_is_vision"), + [ + ("CHAT", True), + ("COMPLETION", True), + ("AGENT_CHAT", True), + ("OTHER", False), + ], + ) + def test_convert_features_all_modes( + self, + mocker, + mock_config_dict, + mock_app_additional_features, + mock_managers, + app_mode, + expected_is_vision, + ): + # Arrange + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode) + + # Assert + assert result == mock_app_additional_features + mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["file_upload"].assert_called_once() + _, kwargs = mock_managers["file_upload"].call_args + assert kwargs["config"] == dict(mock_config_dict.items()) + assert kwargs["is_vision"] is expected_is_vision + + mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items())) + + def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + empty_config = {} + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(empty_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called + + @pytest.mark.parametrize( + "invalid_config", + [ + None, + "string", + 123, + 12.34, + [], + ], + ) + def test_convert_features_invalid_config_raises(self, invalid_config): + # Act & Assert + with pytest.raises((TypeError, AttributeError)): + BaseAppConfigManager.convert_features(invalid_config, "CHAT") + + def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict): + # Arrange + mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + side_effect=RuntimeError("manager failure"), + ) + + # Act & Assert + with pytest.raises(RuntimeError): + BaseAppConfigManager.convert_features(mock_config_dict, "CHAT") + + def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + class CustomMapping(UserDict): + pass + + custom_config = CustomMapping({"a": 1}) + + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(custom_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py new file mode 100644 index 0000000000..eafdf99c16 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -0,0 +1,43 @@ +import pytest + +from core.app.app_config.entities import ( + DatasetRetrieveConfigEntity, + PromptTemplateEntity, +) +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +class TestAppConfigEntities: + def test_variable_entity_coerces_none_description_and_options(self): + entity = VariableEntity( + variable="query", + label="Query", + description=None, + type=VariableEntityType.TEXT_INPUT, + options=None, + ) + + assert entity.description == "" + assert entity.options == [] + + def test_variable_entity_rejects_invalid_json_schema(self): + with pytest.raises(ValueError): + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + json_schema={"type": "string", "minLength": "bad"}, + ) + + def test_prompt_template_value_of(self): + assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE + with pytest.raises(ValueError): + PromptTemplateEntity.PromptType.value_of("missing") + + def test_dataset_retrieve_strategy_value_of(self): + assert ( + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single") + == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + ) + with pytest.raises(ValueError): + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing") diff --git a/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py new file mode 100644 index 0000000000..fa128aca87 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py @@ -0,0 +1,222 @@ +import pytest + +from core.app.app_config.workflow_ui_based_app.variables.manager import ( + WorkflowVariablesConfigManager, +) + +# ============================= +# Fixtures +# ============================= + + +@pytest.fixture +def mock_workflow(mocker): + workflow = mocker.MagicMock() + workflow.graph_dict = {"nodes": []} + return workflow + + +@pytest.fixture +def mock_variable_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity") + + +@pytest.fixture +def mock_rag_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity") + + +# ============================= +# Test Convert (user_input_form) +# ============================= + + +class TestWorkflowVariablesConfigManagerConvert: + def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity): + # Arrange + input_variables = [{"name": "var1"}, {"name": "var2"}] + mock_workflow.user_input_form.return_value = input_variables + mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [{"validated": v} for v in input_variables] + assert mock_variable_entity.model_validate.call_count == 2 + + def test_convert_empty_list(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [] + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [] + mock_variable_entity.model_validate.assert_not_called() + + def test_convert_none_returned_raises(self, mock_workflow): + # Arrange + mock_workflow.user_input_form.return_value = None + + # Act & Assert + with pytest.raises(TypeError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [{"invalid": "data"}] + mock_variable_entity.model_validate.side_effect = ValueError("validation error") + + # Act & Assert + with pytest.raises(ValueError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + +# ============================= +# Test convert_rag_pipeline_variable +# ============================= + + +class TestWorkflowVariablesConfigManagerConvertRag: + def test_no_rag_pipeline_variables(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = [] + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_rag_pipeline_none(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = None + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}] + + def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + @pytest.mark.parametrize( + ("belong_to_node_id", "expected_count"), + [ + ("node1", 1), + ("shared", 1), + ("other_node", 0), + ], + ) + def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": belong_to_node_id}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == expected_count + + def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + + def test_validation_error_propagates(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed") + + # Act & Assert + with pytest.raises(RuntimeError): + WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py index 02a1e04c98..e861a0c684 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py @@ -44,11 +44,22 @@ class TestAgentChatAppGenerateResponseConverterBlocking: metadata={ "retriever_resources": [ { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", "segment_id": "s1", "position": 1, + "data_source_type": "file", "document_name": "doc", "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], } ], "annotation_reply": {"id": "a"}, @@ -107,11 +118,22 @@ class TestAgentChatAppGenerateResponseConverterStream: metadata={ "retriever_resources": [ { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", "segment_id": "s1", "position": 1, + "data_source_type": "file", "document_name": "doc", "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], "summary": "summary", "extra": "ignored", } @@ -151,11 +173,22 @@ class TestAgentChatAppGenerateResponseConverterStream: assert "usage" not in metadata assert metadata["retriever_resources"] == [ { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", "segment_id": "s1", "position": 1, + "data_source_type": "file", "document_name": "doc", "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], "summary": "summary", } ] diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index aba7dfff8c..374af5ddc4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun import uuid from collections.abc import Mapping from dataclasses import dataclass +from datetime import UTC, datetime from typing import Any from unittest.mock import Mock @@ -234,6 +235,50 @@ class TestWorkflowResponseConverter: assert response.data.process_data == {} assert response.data.process_data_truncated is False + def test_workflow_node_finish_response_prefers_event_finished_at( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Finished timestamps should come from the event, not delayed queue processing time.""" + converter = self.create_workflow_response_converter() + start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None) + delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None) + + monkeypatch.setattr( + "core.app.apps.common.workflow_response_converter.naive_utc_now", + lambda: delayed_processing_time, + ) + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) + + event = QueueNodeSucceededEvent( + node_id="test-node-id", + node_type=BuiltinNodeTypes.CODE, + node_execution_id="node-exec-1", + start_at=start_at, + finished_at=finished_at, + in_iteration_id=None, + in_loop_id=None, + inputs={}, + process_data={}, + outputs={}, + execution_metadata={}, + ) + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + ) + + assert response is not None + assert response.data.elapsed_time == 2.0 + assert response.data.finished_at == int(finished_at.timestamp()) + def test_workflow_node_retry_response_uses_truncated_process_data(self): """Test that node retry response uses get_response_process_data().""" converter = self.create_workflow_response_converter() diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py index cf473dfbeb..0136dbf5ad 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py @@ -38,11 +38,22 @@ class TestCompletionAppGenerateResponseConverter: metadata = { "retriever_resources": [ { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", "segment_id": "s", "position": 1, + "data_source_type": "file", "document_name": "doc", "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", "content": "c", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], "summary": "sum", "extra": "x", } @@ -66,7 +77,12 @@ class TestCompletionAppGenerateResponseConverter: assert "annotation_reply" not in result["metadata"] assert "usage" not in result["metadata"] + assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1" + assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1" assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s" + assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file" + assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3 + assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234" assert "extra" not in result["metadata"]["retriever_resources"][0] def test_convert_blocking_simple_response_metadata_not_dict(self): diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py index a25e3ec3f5..f48a7fb38e 100644 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.task_pipeline import message_cycle_manager from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from models.enums import ConversationFromSource from models.model import AppMode, Conversation, Message @@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation(): system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id="user-id", from_account_id=None, ) diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py new file mode 100644 index 0000000000..7c21b00966 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -0,0 +1,12 @@ +from core.app.entities.queue_entities import QueueStopEvent + + +class TestQueueEntities: + def test_get_stop_reason_for_known_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + assert event.get_stop_reason() == "Stopped by user." + + def test_get_stop_reason_for_unknown_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + event.stopped_by = "unknown" + assert event.get_stop_reason() == "Stopped by unknown reason." diff --git a/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py new file mode 100644 index 0000000000..1e0ef6d6d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py @@ -0,0 +1,17 @@ +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity + + +class TestRagPipelineInvokeEntity: + def test_defaults_and_fields(self): + entity = RagPipelineInvokeEntity( + pipeline_id="pipe-1", + application_generate_entity={"foo": "bar"}, + user_id="user-1", + tenant_id="tenant-1", + workflow_id="workflow-1", + streaming=True, + ) + + assert entity.workflow_execution_id is None + assert entity.workflow_thread_pool_id is None + assert entity.streaming is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py new file mode 100644 index 0000000000..8ecab3199c --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -0,0 +1,78 @@ +from core.app.entities.task_entities import ( + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + StreamEvent, +) +from dify_graph.enums import WorkflowNodeExecutionStatus + + +class TestTaskEntities: + def test_node_start_to_ignore_detail_dict(self): + data = NodeStartStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + created_at=1, + ) + response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_STARTED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["extras"] == {} + + def test_node_finish_to_ignore_detail_dict(self): + data = NodeFinishStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ) + response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_FINISHED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["outputs"] is None + assert payload["data"]["files"] == [] + + def test_node_retry_to_ignore_detail_dict(self): + data = NodeRetryStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.RETRY, + elapsed_time=0.1, + created_at=1, + finished_at=2, + retry_index=2, + ) + response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_RETRY.value + assert payload["data"]["retry_index"] == 2 + assert payload["data"]["outputs"] is None diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" diff --git a/api/tests/unit_tests/core/app/features/test_annotation_reply.py b/api/tests/unit_tests/core/app/features/test_annotation_reply.py new file mode 100644 index 0000000000..e721a77079 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_annotation_reply.py @@ -0,0 +1,163 @@ +import logging +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature + + +class TestAnnotationReplyFeature: + def test_query_returns_none_when_setting_missing(self): + feature = AnnotationReplyFeature() + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = None + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_none_when_binding_missing(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace(collection_binding_detail=None) + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = annotation_setting + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_annotation_and_records_history_for_api(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result == annotation + mock_annotation_service.add_annotation_history.assert_called_once() + _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "api" + assert score == 0.8 + + def test_query_returns_annotation_and_records_history_for_console(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=0.5, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=None, + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.EXPLORE, + ) + + assert result == annotation + _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "console" + + def test_query_logs_and_returns_none_on_exception(self, caplog): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1") + mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom") + + with caplog.at_level(logging.WARNING): + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + assert "Query annotation failed" in caplog.text diff --git a/api/tests/unit_tests/core/app/features/test_hosting_moderation.py b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py new file mode 100644 index 0000000000..01194c16f5 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py @@ -0,0 +1,30 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature + + +class TestHostingModerationFeature: + def test_check_aggregates_text_and_calls_moderation(self): + application_generate_entity = Mock() + application_generate_entity.model_conf = {"model": "mock"} + application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1") + + prompt_messages = [ + SimpleNamespace(content="hello"), + SimpleNamespace(content=123), + SimpleNamespace(content="world"), + ] + + with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check: + mock_check.return_value = True + + feature = HostingModerationFeature() + result = feature.check(application_generate_entity, prompt_messages) + + assert result is True + mock_check.assert_called_once_with( + tenant_id="tenant-1", + model_config=application_generate_entity.model_conf, + text="hello\nworld\n", + ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py new file mode 100644 index 0000000000..c6d820dbc9 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -0,0 +1,19 @@ +from core.app.layers.suspend_layer import SuspendLayer +from dify_graph.graph_events.graph import GraphRunPausedEvent + + +class TestSuspendLayer: + def test_on_event_accepts_paused_event(self): + layer = SuspendLayer() + assert layer.is_paused() is False + layer.on_graph_start() + assert layer.is_paused() is False + layer.on_event(GraphRunPausedEvent()) + assert layer.is_paused() is True + + def test_on_event_ignores_other_events(self): + layer = SuspendLayer() + layer.on_graph_start() + initial_state = layer.is_paused() + layer.on_event(object()) + assert layer.is_paused() is initial_state diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py new file mode 100644 index 0000000000..c87eec1508 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -0,0 +1,98 @@ +from unittest.mock import Mock, patch + +from core.app.layers.timeslice_layer import TimeSliceLayer +from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import SchedulerCommand + + +class TestTimeSliceLayer: + def test_init_starts_scheduler_when_not_running(self): + scheduler = Mock() + scheduler.running = False + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock())) + + scheduler.start.assert_called_once() + + def test_on_graph_start_adds_job_for_time_slice(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid, + ): + mock_uuid.return_value.hex = "job-1" + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.on_graph_start() + + assert layer.schedule_id == "job-1" + scheduler.add_job.assert_called_once() + + def test_on_graph_end_removes_job(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.schedule_id = "job-1" + layer.on_graph_end(None) + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_removes_when_stopped(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.stopped = True + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_handles_resource_limit_without_command_channel(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.logger") as mock_logger, + ): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + mock_logger.exception.assert_called_once() + + def test_checker_job_sends_pause_command(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.command_channel = Mock() + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + layer.command_channel.send_command.assert_called_once() + sent_command = layer.command_channel.send_command.call_args[0][0] + assert isinstance(sent_command, GraphEngineCommand) + assert sent_command.command_type == CommandType.PAUSE diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py new file mode 100644 index 0000000000..f9755061d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -0,0 +1,106 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.layers.trigger_post_layer import TriggerPostLayer +from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent +from models.enums import WorkflowTriggerStatus + + +class TestTriggerPostLayer: + def test_on_event_updates_trigger_log(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"answer": "ok"}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=12, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunSucceededEvent()) + + assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 12 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + + def test_on_event_handles_missing_trigger_log(self): + runtime_state = SimpleNamespace( + outputs={}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=0, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.logger") as mock_logger, + ): + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = None + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="missing", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunFailedEvent(error="boom")) + + mock_logger.exception.assert_called_once() + session.commit.assert_not_called() + + def test_on_event_ignores_non_status_events(self): + runtime_state = SimpleNamespace( + outputs={}, + system_variable=SimpleNamespace(workflow_execution_id="run-1"), + total_tokens=0, + ) + + with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory: + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(Mock()) + + mock_session_factory.create_session.assert_not_called() diff --git a/api/tests/unit_tests/core/app/task_pipeline/__init__.py b/api/tests/unit_tests/core/app/task_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py new file mode 100644 index 0000000000..e070eb06fd --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.entities.queue_entities import QueueErrorEvent +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.errors.error import QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from models.enums import MessageStatus + + +class TestBasedGenerateTaskPipeline: + @pytest.fixture + def pipeline(self): + app_config = SimpleNamespace( + tenant_id="tenant-1", + app_id="app-1", + sensitive_word_avoidance=None, + ) + app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config) + return BasedGenerateTaskPipeline( + application_generate_entity=app_generate_entity, + queue_manager=Mock(), + stream=True, + ) + + def test_error_to_desc_quota_exceeded(self, pipeline): + message = pipeline._error_to_desc(QuotaExceededError()) + assert "quota" in message.lower() + + def test_handle_error_wraps_invoke_authorization(self, pipeline): + event = QueueErrorEvent(error=InvokeAuthorizationError()) + err = pipeline.handle_error(event=event) + assert isinstance(err, InvokeAuthorizationError) + assert str(err) == "Incorrect API key provided" + + def test_handle_error_preserves_invoke_error(self, pipeline): + event = QueueErrorEvent(error=InvokeError("bad")) + err = pipeline.handle_error(event=event) + assert err is event.error + + def test_handle_error_updates_message_when_found(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + message = SimpleNamespace(status=MessageStatus.NORMAL, error=None) + session = Mock() + session.scalar.return_value = message + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + assert message.status == MessageStatus.ERROR + assert message.error == "oops" + + def test_handle_error_returns_err_when_message_missing(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + session = Mock() + session.scalar.return_value = None + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + + def test_error_to_stream_response_and_ping(self, pipeline): + error_response = pipeline.error_to_stream_response(ValueError("boom")) + ping_response = pipeline.ping_stream_response() + + assert error_response.task_id == "task-1" + assert ping_response.task_id == "task-1" + + def test_handle_output_moderation_when_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("filtered", True) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result == "filtered" + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None + + def test_handle_output_moderation_when_not_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("safe", False) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result is None + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py new file mode 100644 index 0000000000..155e6f2c73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -0,0 +1,1228 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppStreamResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AudioTrunk +from dify_graph.file.enums import FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent +from models.model import AppMode + + +class _DummyModelConf: + def __init__(self) -> None: + self.model = "mock" + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock", model="mock"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_entity(entity_cls, app_mode: AppMode): + app_config = _make_app_config(app_mode) + return entity_cls.model_construct( + task_id="task", + app_config=app_config, + model_conf=_DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +class TestEasyUiBasedGenerateTaskPipeline: + def test_to_blocking_response_chat(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_to_blocking_response_completion(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_listen_audio_msg_returns_none_when_no_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_process_stream_response_handles_chunks_and_end(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[TextPromptMessageContent(data="hi"), TextPromptMessageContent(data="yo")] + ), + ), + ) + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + events = [ + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline.handle_output_moderation_when_task_finished = lambda completion: None + pipeline._message_end_to_stream_response = lambda: "end" + pipeline._save_message = lambda **kwargs: None + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "chunk" in responses + assert "replace" in responses + assert any(isinstance(item, PingStreamResponse) for item in responses) + assert responses[-1] == "end" + + def test_handle_output_moderation_chunk_directs_output(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline.output_moderation_handler = _Moderation() + pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is True + assert any(isinstance(event, QueueLLMChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_stop_updates_usage(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + class _ModelType: + def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): + return LLMUsage.from_metadata( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + ) + + class _ModelConf: + def __init__(self) -> None: + self.model = "mock" + self.credentials = {} + self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + + app_config = _make_app_config(AppMode.CHAT) + application_generate_entity = ChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + model_conf=_ModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content="answer") + + calls: list[int] = [] + + class _FakeModelInstance: + def __init__(self, provider_model_bundle, model): + pass + + def get_llm_num_tokens(self, messages): + calls.append(1) + return 10 if len(calls) == 1 else 5 + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.ModelInstance", + _FakeModelInstance, + ) + + pipeline._handle_stop(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) + + assert pipeline._task_state.llm_result.usage.prompt_tokens == 10 + assert pipeline._task_state.llm_result.usage.completion_tokens == 5 + + def test_record_files_builds_file_payloads(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + message_files = [ + SimpleNamespace( + id="mf-1", + message_id="msg", + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/a.png", + upload_file_id=None, + type="image", + ), + SimpleNamespace( + id="mf-2", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-1", + type="image", + ), + SimpleNamespace( + id="mf-3", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/file.bin", + upload_file_id=None, + type="file", + ), + ] + upload_files = [ + SimpleNamespace( + id="upload-1", + name="local.png", + mime_type="image/png", + size=123, + extension="png", + ) + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else upload_files) + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "signed-url", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "signed-tool", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files + assert len(files) == 3 + + def test_process_stream_response_handles_annotation_and_error(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="agent"), + ), + ) + + events = [ + SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), + SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), + SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), + SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), + SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") + pipeline._agent_thought_to_stream_response = lambda event: "thought" + pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" + pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" + pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline.error_to_stream_response = lambda err: err + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "thought" in responses + assert "file" in responses + assert "agent" in responses + assert isinstance(responses[-1], ValueError) + assert pipeline._task_state.llm_result.message.content == "annotatedagent" + + def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_thought = SimpleNamespace( + id="thought", + position=1, + thought="t", + observation="o", + tool="tool", + tool_labels={}, + tool_input="input", + files=[], + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return agent_thought + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) + + assert response is not None + assert response.id == "thought" + + def test_process_routes_to_stream_and_starts_conversation_name_generation(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_stream_response = lambda generator: "streamed" + + result = pipeline.process() + + assert result == "streamed" + pipeline._message_cycle_manager.generate_conversation_name.assert_called_once_with( + conversation_id="conv", query="hello" + ) + + def test_process_routes_to_blocking_for_completion_mode(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock() + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_blocking_response = lambda generator: "blocking" + + result = pipeline.process() + + assert result == "blocking" + pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() + + def test_to_blocking_response_raises_error_stream_exception(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("stream error")) + + with pytest.raises(ValueError, match="stream error"): + pipeline._to_blocking_response(_gen()) + + def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(RuntimeError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_gen()) + + def test_to_stream_response_wraps_completion_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, CompletionAppStreamResponse) + assert response.message_id == "msg" + + def test_to_stream_response_wraps_chat_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, ChatbotAppStreamResponse) + assert response.conversation_id == "conv" + + def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + assert response.audio == "abc" + + def test_listen_audio_msg_returns_none_for_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + + assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses == ["payload"] + + def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def check_and_get_audio(self): + return AudioTrunk("finish", "") + + inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") + audio_calls = iter([inline_audio, None]) + pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses[0] == inline_audio + assert responses[1] == "payload" + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def __init__(self): + self._events = iter([None, AudioTrunk("responding", "later"), AudioTrunk("finish", "")]) + + def check_and_get_audio(self): + return next(self._events) + + clock = {"value": 0.0} + + def _fake_time(): + clock["value"] += 0.1 + return clock["value"] + + pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.time", _fake_time) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.sleep", lambda _: None) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._task_state.llm_result.message.content = "raw answer" + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_stop = Mock() + pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" + pipeline._save_message = lambda **kwargs: None + pipeline._message_end_to_stream_response = lambda: "end" + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["replace:moderated answer", "end"] + pipeline._handle_stop.assert_called_once() + + def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + handled = {"retriever": 0} + + def _handle_retriever_resources(event): + handled["retriever"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources + pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=retriever_event), + SimpleNamespace(event=SimpleNamespace()), + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + ] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + assert handled["retriever"] == 1 + + def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), + ) + pipeline._handle_output_moderation_chunk = lambda text: True + pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + + def test_process_stream_response_ignores_unsupported_chunk_content_types(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = SimpleNamespace( + prompt_messages=[], + delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), + ) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["ok"] + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = object() + pipeline.queue_manager.listen = lambda: iter([]) + + assert list(pipeline._process_stream_response(publisher=None)) == [] + + def test_save_message_persists_fields_and_emits_trace(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline.start_at = 10.0 + pipeline._model_config = SimpleNamespace(mode="chat") + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( + {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} + ) + + message_obj = SimpleNamespace(id="msg") + conversation_obj = SimpleNamespace(id="conv") + session = Mock() + session.scalar.side_effect = [message_obj, conversation_obj] + trace_manager = SimpleNamespace(add_trace_task=Mock()) + sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptMessageUtil.prompt_messages_to_prompt_for_saving", + lambda mode, prompt_messages: "serialized-prompt", + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptTemplateParser.remove_template_variables", + lambda text: text.replace("{{name}}", "").strip(), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.naive_utc_now", + lambda: datetime(2024, 1, 1, tzinfo=UTC), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.perf_counter", lambda: 15.0 + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.message_was_created.send", + lambda *args, **kwargs: sent_payloads.append((args, kwargs)), + ) + + pipeline._save_message(session=session, trace_manager=trace_manager) + + assert message_obj.message == "serialized-prompt" + assert message_obj.answer == "hello" + assert message_obj.provider_response_latency == 5.0 + assert trace_manager.add_trace_task.called + assert len(sent_payloads) == 1 + + def test_save_message_raises_when_message_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.return_value = None + + with pytest.raises(ValueError, match="message msg not found"): + pipeline._save_message(session=session) + + def test_save_message_raises_when_conversation_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + + with pytest.raises(ValueError, match="Conversation conv not found"): + pipeline._save_message(session=session) + + def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 2}) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.id == "msg" + assert response.metadata["usage"]["prompt_tokens"] == 1 + + def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.files is None + + def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + message_files = [ + SimpleNamespace( + id="mf-local-fallback", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-missing", + type="file", + ), + SimpleNamespace( + id="mf-tool-http", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="http://cdn.example.com/file.txt?x=1", + upload_file_id=None, + type="file", + ), + SimpleNamespace( + id="mf-tool-noext", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/path/toolid", + upload_file_id=None, + type="file", + ), + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else []) + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "local-fallback-signed", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "tool-signed", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files is not None + assert files[0]["url"] == "local-fallback-signed" + assert files[1]["filename"] == "file.txt" + assert files[2]["extension"] == ".bin" + + def test_agent_message_to_stream_response_builds_payload(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + response = pipeline._agent_message_to_stream_response(answer="hello", message_id="msg") + + assert response.id == "msg" + assert response.answer == "hello" + + def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) + + assert response is None + + def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + appended_tokens: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + appended_tokens.append(text) + + pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("next-token") + + assert result is False + assert appended_tokens == ["next-token"] diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..37dd116470 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from dify_graph.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_exc.py b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py new file mode 100644 index 0000000000..9ea7e96e73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py @@ -0,0 +1,11 @@ +from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError + + +class TestTaskPipelineExceptions: + def test_record_not_found_error_message(self): + err = RecordNotFoundError("Message", "msg-1") + assert str(err) == "Message with id msg-1 not found" + + def test_workflow_run_not_found_error_message(self): + err = WorkflowRunNotFoundError("run-1") + assert str(err) == "WorkflowRun with id run-1 not found" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index c0c636715d..07ee75ed35 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,12 +1,16 @@ """Unit tests for the message cycle manager optimization.""" +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from flask import current_app +from flask import Flask, current_app -from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from models.model import AppMode class TestMessageCycleManagerOptimization: @@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE mock_session.scalar.assert_called_once() + def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager): + """Return MESSAGE_FILE directly from in-memory cache without opening a DB session.""" + message_cycle_manager._message_has_file.add("cached-message") + + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + result = message_cycle_manager.get_message_event_type("cached-message") + + assert result == StreamEvent.MESSAGE_FILE + mock_session_factory.create_session.assert_not_called() + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization: assert chunk2_response.event == StreamEvent.MESSAGE assert chunk1_response.answer == "Chunk 1" assert chunk2_response.answer == "Chunk 2" + + def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager): + """Return None when completion entities are used for conversation naming. + + Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity. + Returns: None, indicating no name generation for completion apps. + Side effects: None expected. + """ + + class DummyCompletion: + pass + + with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion): + message_cycle_manager._application_generate_entity = DummyCompletion() + result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi") + + assert result is None + + def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager): + """Spawn background generation thread for the first chat message.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True} + flask_app = object() + + class DummyTimer: + def __init__(self, interval, function, args=None, kwargs=None): + self.interval = interval + self.function = function + self.args = args or [] + self.kwargs = kwargs + self.daemon = False + self.started = False + + def start(self): + self.started = True + + with ( + patch( + "core.app.task_pipeline.message_cycle_manager.current_app", + new=SimpleNamespace(_get_current_object=lambda: flask_app), + ), + patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer), + ): + thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello") + + assert isinstance(thread, DummyTimer) + assert thread.interval == 1 + assert thread.function == message_cycle_manager._generate_conversation_name_worker + assert thread.started is True + assert thread.daemon is True + assert thread.kwargs["flask_app"] is flask_app + assert thread.kwargs["conversation_id"] == "conv-1" + assert thread.kwargs["query"] == "hello" + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + + def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager): + """Skip thread creation when auto naming is disabled but still mark conversation as not new.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False} + + with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer: + result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello") + + assert result is None + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + mock_timer.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager): + """Return early when the conversation cannot be found.""" + flask_app = Flask(__name__) + db_session = Mock() + db_session.scalar.return_value = None + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager): + """Return early when non-completion conversation has no app relation.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id") + db_session = Mock() + db_session.scalar.return_value = conversation + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager): + """Use cached conversation name when present and avoid LLM call.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = b"cached-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "cached-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_llm_generator.generate_conversation_name.assert_not_called() + mock_redis.setex.assert_not_called() + + def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager): + """Generate conversation name and write it to redis cache on cache miss.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.return_value = "generated-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "generated-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + """Fallback to truncated query when LLM generation fails.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + long_query = "q" * 60 + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, + patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") + mock_dify_config.DEBUG = True + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + + assert conversation.name == (long_query[:47] + "...") + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_logger.exception.assert_called_once() + + def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): + """Populate task metadata from annotation reply events. + + Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService. + Returns: The fetched annotation object. + Side effects: Updates metadata.annotation_reply with id and account name. + """ + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + annotation = SimpleNamespace( + id="ann-1", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = annotation + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="ann-1") + ) + + assert result == annotation + assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1" + assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice" + + def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager): + """Return None and keep metadata unchanged when annotation is not found.""" + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = None + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="missing") + ) + + assert result is None + assert message_cycle_manager._task_state.metadata.annotation_reply is None + + def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager): + """Merge retriever resources, deduplicate, and preserve ordering positions. + + Args: message_cycle_manager with show_retrieve_source enabled and existing metadata. + Returns: None. + Side effects: Updates metadata.retriever_resources with unique items and positions. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing])) + + duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2") + + event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource]) + message_cycle_manager.handle_retriever_resources(event) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2 + + def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager): + """Build a stream response with a signed tool file URL. + + Args: message_cycle_manager with mocked Session/db and sign_tool_file. + Returns: MessageStreamResponse with signed url and belongs_to normalized to user. + Side effects: Calls sign_tool_file for tool file ids. + """ + message_cycle_manager._application_generate_entity.task_id = "task-1" + + message_file = SimpleNamespace( + id="file-1", + type="image", + belongs_to=None, + url="tool://file.verylongextension", + message_id="msg-1", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-url" + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1")) + + assert response.url == "signed-url" + assert response.belongs_to == "user" + mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin") + + def test_handle_retriever_resources_requires_features(self, message_cycle_manager): + """Raise when retriever resources are handled without feature config. + + Args: message_cycle_manager with additional_features unset and empty metadata. + Raises: ValueError when show_retrieve_source configuration is missing. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with pytest.raises(ValueError): + message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[])) + + def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager): + """Ignore null resource entries while preserving valid resources.""" + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[])) + resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + + message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource])) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + + def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager): + """Use original URL when message file URL is already HTTP.""" + message_cycle_manager._application_generate_entity.task_id = "task-http" + message_file = SimpleNamespace( + id="file-http", + type="image", + belongs_to="assistant", + url="http://example.com/pic.png", + message_id="msg-http", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-http") + ) + + assert response is not None + assert response.url == "http://example.com/pic.png" + assert "msg-http" in message_cycle_manager._message_has_file + + def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager): + """Default tool file extension to .bin when URL has no extension part.""" + message_cycle_manager._application_generate_entity.task_id = "task-bin" + message_file = SimpleNamespace( + id="file-bin", + type="file", + belongs_to="assistant", + url="tool-file-id", + message_id="msg-bin", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-bin-url" + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-bin") + ) + + assert response is not None + assert response.url == "signed-bin-url" + mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin") + + def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager): + """Return None when message file lookup does not find a record.""" + session = Mock() + session.scalar.return_value = None + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing")) + + assert response is None + + def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager): + """Include the provided replacement reason in the stream payload.""" + response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation") + + assert response.answer == "replaced" + assert response.reason == "moderation" diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py new file mode 100644 index 0000000000..0f8a846d11 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -0,0 +1,60 @@ +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest + +from core.app.workflow.layers.persistence import ( + PersistenceWorkflowInfo, + WorkflowPersistenceLayer, + _NodeRuntimeSnapshot, +) +from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType +from dify_graph.node_events import NodeRunResult + + +def _build_layer() -> WorkflowPersistenceLayer: + application_generate_entity = Mock() + application_generate_entity.inputs = {} + + return WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={}, + ), + workflow_execution_repository=Mock(), + workflow_node_execution_repository=Mock(), + ) + + +def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-1" + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="node-id", + title="LLM", + predecessor_node_id=None, + iteration_id="iter-1", + loop_id=None, + created_at=node_execution.created_at, + ) + + event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None) + delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None) + monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time) + + layer._update_node_execution( + node_execution, + NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), + WorkflowNodeExecutionStatus.SUCCEEDED, + finished_at=event_finished_at, + ) + + assert node_execution.finished_at == event_finished_at + assert node_execution.elapsed_time == 2.0 diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py new file mode 100644 index 0000000000..fb76f22a2a --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -0,0 +1,43 @@ +from unittest.mock import patch + +from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime + + +class TestDifyWorkflowFileRuntime: + def test_runtime_properties_and_helpers(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + + runtime = DifyWorkflowFileRuntime() + + assert runtime.files_url == "http://files" + assert runtime.internal_files_url == "http://internal" + assert runtime.secret_key == "secret" + assert runtime.files_access_timeout == 123 + assert runtime.multimodal_send_format == "url" + + with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get: + mock_get.return_value = "response" + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch("core.app.workflow.file_runtime.storage.load") as mock_load: + mock_load.return_value = b"data" + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign: + mock_sign.return_value = "signed" + assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed" + mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False) + + def test_bind_runtime_registers_instance(self): + with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set: + bind_dify_workflow_file_runtime() + + mock_set.assert_called_once() + runtime = mock_set.call_args[0][0] + assert isinstance(runtime, DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py new file mode 100644 index 0000000000..9e742507c6 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -0,0 +1,161 @@ +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import BuiltinNodeTypes + + +class DummyNode: + def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = id + self.config = config + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self.kwargs = kwargs + + +class DummyCodeNode(DummyNode): + @classmethod + def default_code_providers(cls): + return () + + +class DummyTemplateTransformNode(DummyNode): + pass + + +class DummyHttpRequestNode(DummyNode): + pass + + +class DummyKnowledgeRetrievalNode(DummyNode): + pass + + +class DummyDocumentExtractorNode(DummyNode): + pass + + +class TestDifyNodeFactory: + @staticmethod + def _stub_node_resolution(monkeypatch, node_class): + monkeypatch.setattr( + "core.workflow.node_factory.resolve_workflow_node_class", + lambda **_kwargs: node_class, + ) + + def _factory(self, monkeypatch): + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100) + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u") + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key") + + run_context = build_dify_run_context( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + return DifyNodeFactory( + graph_init_params=SimpleNamespace(run_context=run_context), + graph_runtime_state=SimpleNamespace(), + ) + + def test_create_node_unknown_type(self, monkeypatch): + factory = self._factory(monkeypatch) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": "unknown"}}) + + def test_create_node_missing_mapping(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {}) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_missing_latest_class(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr( + "core.workflow.node_factory.get_node_type_classes_mapping", + lambda: {BuiltinNodeTypes.START: {"1": None}}, + ) + monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest") + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_selects_versioned_class(self, monkeypatch): + factory = self._factory(monkeypatch) + selected_versions: list[tuple[str, str]] = [] + + class DummyNodeV2(DummyNode): + pass + + def _get_mapping(): + selected_versions.append(("snapshot", "called")) + return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}} + + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}}) + + assert isinstance(node, DummyNodeV2) + assert node.id == "node-1" + assert selected_versions == [("snapshot", "called")] + + def test_create_node_code_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyCodeNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}}) + + assert isinstance(node, DummyCodeNode) + assert node.id == "node-1" + + def test_create_node_template_transform_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) + + assert isinstance(node, DummyTemplateTransformNode) + assert "template_renderer" in node.kwargs + + def test_create_node_http_request_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyHttpRequestNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}) + + assert isinstance(node, DummyHttpRequestNode) + assert "http_request_config" in node.kwargs + + def test_create_node_knowledge_retrieval_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}) + + assert isinstance(node, DummyKnowledgeRetrievalNode) + assert node.kwargs == {} + + def test_create_node_document_extractor_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}}) + + assert isinstance(node, DummyDocumentExtractorNode) + assert "unstructured_api_config" in node.kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py new file mode 100644 index 0000000000..0565f4cfe9 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from core.app.workflow.layers.observability import ObservabilityLayer +from dify_graph.enums import BuiltinNodeTypes + + +class TestObservabilityLayerExtras: + def test_init_tracer_enabled_sets_tracer(self, monkeypatch): + tracer = object() + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer) + + layer = ObservabilityLayer() + + assert layer._is_disabled is False + assert layer._tracer is tracer + + def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + def _raise(*_args, **_kwargs): + raise RuntimeError("tracer init failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + assert layer._tracer is None + assert "Failed to get OpenTelemetry tracer" in caplog.text + + def test_init_tracer_disables_when_otel_disabled(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + + def test_get_parser_uses_registry_when_node_type_matches(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL)) + + assert parser is layer._parsers[BuiltinNodeTypes.TOOL] + + def test_get_parser_defaults_when_node_type_missing(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=None)) + + assert parser is layer._default_parser + + def test_on_graph_start_clears_contexts(self): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_start() + + assert layer._node_contexts == {} + + def test_on_event_is_noop(self): + layer = ObservabilityLayer() + + layer.on_event(object()) + + def test_on_graph_end_clears_unfinished_contexts(self, caplog): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_end(error=None) + + assert layer._node_contexts == {} + assert "node spans were not properly ended" in caplog.text + + def test_on_node_run_start_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + layer._tracer = None + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object()) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self): + layer = ObservabilityLayer() + layer._is_disabled = False + calls: list[str] = [] + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called")) + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert calls == [] + + def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + def _raise(*_args, **_kwargs): + raise RuntimeError("start failed") + + layer._tracer = SimpleNamespace(start_span=_raise) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert "Failed to create OpenTelemetry span for node" in caplog.text + + def test_on_node_run_end_without_context_noop(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None) + + assert "exec" in layer._node_contexts + + def test_on_node_run_end_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_calls_span_end(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + ended: list[str] = [] + + class _Parser: + def parse(self, **_kwargs): + return None + + span = SimpleNamespace(end=lambda: ended.append("ended")) + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert ended == ["ended"] + assert "exec" not in layer._node_contexts + + def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + class _Parser: + def parse(self, **_kwargs): + return None + + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token") + + def _raise(*_args, **_kwargs): + raise RuntimeError("detach failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert "Failed to detach OpenTelemetry token" in caplog.text + assert "exec" not in layer._node_contexts + + def test_on_node_run_start_and_end_creates_span(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + + span = SimpleNamespace(end=lambda: None) + tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span) + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token") + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None) + + layer._tracer = tracer + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + + layer.on_node_run_start(node) + assert "exec" in layer._node_contexts + + layer.on_node_run_end(node, error=None) + assert "exec" not in layer._node_contexts diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py new file mode 100644 index 0000000000..45f6a0c7a1 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + SystemVariableKey, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from dify_graph.graph_events.graph import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from dify_graph.graph_events.node import ( + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from dify_graph.node_events import NodeRunResult +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.system_variable import SystemVariable + + +class _RepoRecorder: + def __init__(self) -> None: + self.saved: list[object] = [] + self.saved_exec_data: list[object] = [] + + def save(self, entity): + self.saved.append(entity) + + def save_execution_data(self, entity): + self.saved_exec_data.append(entity) + + +def _naive_utc_now() -> datetime: + return datetime.now(UTC).replace(tzinfo=None) + + +def _make_layer( + system_variable: SystemVariable | None = None, + *, + extras: dict | None = None, + trace_manager: object | None = None, +): + system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id") + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0) + read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) + + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=SimpleNamespace(app_id="app", tenant_id="tenant"), + inputs={"foo": "bar"}, + files=[], + user_id="user", + stream=False, + invoke_from=None, + trace_manager=None, + workflow_execution_id="run-id", + extras=extras or {}, + call_depth=0, + ) + + workflow_info = PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={"nodes": [], "edges": []}, + ) + + workflow_execution_repo = _RepoRecorder() + workflow_node_execution_repo = _RepoRecorder() + + layer = WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=workflow_info, + workflow_execution_repository=workflow_execution_repo, + workflow_node_execution_repository=workflow_node_execution_repo, + trace_manager=trace_manager, + ) + layer.initialize(read_only_state, command_channel=None) + + return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state + + +class TestWorkflowPersistenceLayer: + def test_on_graph_start_resets_state(self): + layer, _, _, _ = _make_layer() + layer._workflow_execution = object() + layer._node_execution_cache["cached"] = object() + layer._node_snapshots["cached"] = object() + layer._node_sequence = 9 + + layer.on_graph_start() + + assert layer._workflow_execution is None + assert layer._node_execution_cache == {} + assert layer._node_snapshots == {} + assert layer._node_sequence == 0 + + def test_get_execution_id_requires_system_variable(self): + system_variable = SystemVariable(workflow_execution_id=None) + layer, _, _, _ = _make_layer(system_variable) + + with pytest.raises(ValueError, match="workflow_execution_id must be provided"): + layer._get_execution_id() + + def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch): + layer, _, _, _ = _make_layer() + + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda inputs: inputs, + ) + + inputs = layer._prepare_workflow_inputs() + + assert "sys.conversation_id" not in inputs + assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id" + + def test_fail_running_node_executions_marks_failed(self): + layer, _, node_repo, _ = _make_layer() + + execution = WorkflowNodeExecution( + id="exec-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[execution.id] = execution + + layer._fail_running_node_executions(error_message="boom") + + assert execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_repo.saved + + def test_handle_graph_run_started_saves_execution(self): + layer, exec_repo, _, _ = _make_layer() + + layer._handle_graph_run_started() + + assert exec_repo.saved + + def test_handle_graph_run_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 3 + runtime_state.node_run_steps = 2 + runtime_state.outputs = {"out": "v"} + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.SUCCEEDED + assert saved.total_tokens == 3 + assert saved.total_steps == 2 + + def test_handle_graph_run_partial_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 5 + runtime_state.node_run_steps = 4 + runtime_state._graph_execution = SimpleNamespace(exceptions_count=2) + + layer._handle_graph_run_partial_succeeded( + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2) + ) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert saved.exceptions_count == 2 + assert saved.total_tokens == 5 + + def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self): + trace_tasks: list[object] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager) + layer._handle_graph_run_started() + + running = WorkflowNodeExecution( + id="node-exec", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[running.id] = running + + layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1)) + + assert node_repo.saved + assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED + assert trace_tasks + + def test_handle_graph_run_aborted_sets_status(self): + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + + layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.STOPPED + assert saved.error_message + + def test_handle_graph_run_paused_updates_outputs(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 7 + runtime_state.node_run_steps = 5 + + layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PAUSED + assert saved.outputs == {"pause": True} + assert saved.finished_at is None + + def test_handle_node_started_and_retry(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + predecessor_node_id="prev", + in_iteration_id="iter", + in_loop_id="loop", + ) + layer._handle_node_started(start_event) + + assert node_repo.saved + assert "exec" in layer._node_execution_cache + assert layer._node_snapshots["exec"].node_id == "node" + + retry_event = NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + layer._handle_node_retry(retry_event) + assert node_repo.saved_exec_data + + def test_handle_node_result_events_update_execution(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={}) + success_event = NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + node_run_result=result, + ) + layer._handle_node_succeeded(success_event) + + failed_event = NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="boom", + node_run_result=result, + ) + layer._handle_node_failed(failed_event) + + exception_event = NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="err", + node_run_result=result, + ) + layer._handle_node_exception(exception_event) + + assert node_repo.saved_exec_data + + def test_handle_node_pause_requested_skips_outputs(self): + layer, _, _, _ = _make_layer() + layer._handle_graph_run_started() + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + domain_execution = layer._node_execution_cache["exec"] + domain_execution.inputs = {"old": True} + + result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={}) + pause_event = NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + reason=SchedulingPause(message="pause"), + node_run_result=result, + ) + layer._handle_node_pause_requested(pause_event) + + assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED + assert domain_execution.inputs == {"old": True} + + def test_get_node_execution_raises_for_missing(self): + layer, _, _, _ = _make_layer() + with pytest.raises(ValueError, match="Node execution not found"): + layer._get_node_execution("missing") + + def test_get_workflow_execution_raises_when_uninitialized(self): + layer, _, _, _ = _make_layer() + + with pytest.raises(ValueError, match="workflow execution not initialized"): + layer._get_workflow_execution() + + def test_next_node_sequence_increments(self): + layer, _, _, _ = _make_layer() + assert layer._next_node_sequence() == 1 + assert layer._next_node_sequence() == 2 + + def test_on_graph_end_is_noop(self): + layer, _, _, _ = _make_layer() + + assert layer.on_graph_end(error=None) is None + + def test_on_event_dispatches_to_all_known_handlers(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_graph_run_started = _record("started") + layer._handle_graph_run_succeeded = _record("succeeded") + layer._handle_graph_run_partial_succeeded = _record("partial") + layer._handle_graph_run_failed = _record("failed") + layer._handle_graph_run_aborted = _record("aborted") + layer._handle_graph_run_paused = _record("paused") + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + layer._handle_node_succeeded = _record("node_succeeded") + layer._handle_node_failed = _record("node_failed") + layer._handle_node_exception = _record("node_exception") + layer._handle_node_pause_requested = _record("node_paused") + + node_result = NodeRunResult() + now = _naive_utc_now() + events = [ + GraphRunStartedEvent(), + GraphRunSucceededEvent(outputs={"ok": True}), + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1), + GraphRunFailedEvent(error="boom", exceptions_count=1), + GraphRunAbortedEvent(reason="stop", outputs={"x": 1}), + GraphRunPausedEvent(outputs={"pause": True}), + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + ), + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + error="retry", + retry_index=1, + ), + NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + node_run_result=node_result, + ), + NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="failed", + node_run_result=node_result, + ), + NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="error", + node_run_result=node_result, + ), + NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + reason=SchedulingPause(message="pause"), + node_run_result=node_result, + ), + ] + expected_order = [ + "started", + "succeeded", + "partial", + "failed", + "aborted", + "paused", + "node_started", + "node_retry", + "node_succeeded", + "node_failed", + "node_exception", + "node_paused", + ] + + for event in events: + layer.on_event(event) + + assert called == expected_order + + def test_on_event_dispatches_retry_before_started_for_retry_event(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + + layer.on_event( + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + ) + + assert called == ["node_retry"] + + def test_enqueue_trace_task_skips_when_disabled(self): + trace_tasks: list[object] = [] + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + assert exec_repo.saved + assert not trace_tasks diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py index a7c93242cd..7cd1fdf06b 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -166,6 +166,7 @@ class TestDatasourceFileManager: # Setup mock_guess_ext.return_value = None # Cannot guess mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" # Execute upload_file = DatasourceFileManager.create_file_by_raw( diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 5ebefcd8d2..95d58757f1 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) +from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -409,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() - configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) - assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value + assert existing_record.preferred_provider_type == ProviderType.SYSTEM session.commit.assert_called_once() @@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-base", name="LB Base", credentials={}, - credential_source_type="provider", + credential_source_type=CredentialSourceType.PROVIDER, ) ], ), @@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-custom", name="LB Custom", credentials={}, - credential_source_type="custom_model", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) ], ), @@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None: configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) @@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) diff --git a/api/tests/unit_tests/core/moderation/api/test_api.py b/api/tests/unit_tests/core/moderation/api/test_api.py new file mode 100644 index 0000000000..558b20e5f8 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/api/test_api.py @@ -0,0 +1,181 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.extension.api_based_extension_requestor import APIBasedExtensionPoint +from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams +from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult +from models.api_based_extension import APIBasedExtension + + +class TestApiModeration: + @pytest.fixture + def api_config(self): + return { + "inputs_config": { + "enabled": True, + }, + "outputs_config": { + "enabled": True, + }, + "api_based_extension_id": "test-extension-id", + } + + @pytest.fixture + def api_moderation(self, api_config): + return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config) + + def test_moderation_input_params(self): + params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query") + assert params.app_id == "app-1" + assert params.inputs == {"key": "val"} + assert params.query == "test query" + + # Test defaults + params_default = ModerationInputParams() + assert params_default.app_id == "" + assert params_default.inputs == {} + assert params_default.query == "" + + def test_moderation_output_params(self): + params = ModerationOutputParams(app_id="app-1", text="test text") + assert params.app_id == "app-1" + assert params.text == "test text" + + with pytest.raises(ValidationError): + ModerationOutputParams() + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_success(self, mock_get_extension, api_config): + mock_get_extension.return_value = MagicMock(spec=APIBasedExtension) + ApiModeration.validate_config("test-tenant-id", api_config) + mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id") + + def test_validate_config_missing_extension_id(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": True}, + } + with pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiModeration.validate_config("test-tenant-id", config) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_extension_not_found(self, mock_get_extension, api_config): + mock_get_extension.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + ApiModeration.validate_config("test-tenant-id", api_config) + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"} + + result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello") + + assert isinstance(result, ModerationInputsResult) + assert result.flagged is True + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "Blocked by API" + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_INPUT, + {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"}, + ) + + def test_moderation_for_inputs_disabled(self): + config = { + "inputs_config": {"enabled": False}, + "outputs_config": {"enabled": True}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_inputs(inputs={}, query="") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "" + + def test_moderation_for_inputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_inputs({}, "") + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""} + + result = api_moderation.moderation_for_outputs(text="hello world") + + assert isinstance(result, ModerationOutputsResult) + assert result.flagged is False + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"} + ) + + def test_moderation_for_outputs_disabled(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": False}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_outputs(text="test") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + + def test_moderation_for_outputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_outputs("test") + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + @patch("core.moderation.api.api.decrypt_token") + @patch("core.moderation.api.api.APIBasedExtensionRequestor") + def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_ext.api_endpoint = "http://api.test" + mock_ext.api_key = "encrypted-key" + mock_get_ext.return_value = mock_ext + + mock_decrypt.return_value = "decrypted-key" + + mock_requestor = MagicMock() + mock_requestor.request.return_value = {"flagged": True} + mock_requestor_cls.return_value = mock_requestor + + params = {"some": "params"} + result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + assert result == {"flagged": True} + mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id") + mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key") + mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key") + mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + def test_get_config_by_requestor_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation): + mock_get_ext.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.db.session.scalar") + def test_get_api_based_extension(self, mock_scalar): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_scalar.return_value = mock_ext + + result = ApiModeration._get_api_based_extension("tenant-1", "ext-1") + + assert result == mock_ext + mock_scalar.assert_called_once() + # Verify the call has the correct filters + args, kwargs = mock_scalar.call_args + stmt = args[0] + # We can't easily inspect the statement without complex sqlalchemy tricks, + # but calling it is usually enough for unit tests if we mock the result. diff --git a/api/tests/unit_tests/core/moderation/test_input_moderation.py b/api/tests/unit_tests/core/moderation/test_input_moderation.py new file mode 100644 index 0000000000..2dbc80cf14 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_input_moderation.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity +from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult +from core.moderation.input_moderation import InputModeration +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager + + +class TestInputModeration: + @pytest.fixture + def app_config(self): + config = MagicMock(spec=AppConfig) + config.sensitive_word_avoidance = None + return config + + @pytest.fixture + def input_moderation(self): + return InputModeration() + + def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {"keywords": ["bad"]} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + mock_factory_cls.assert_called_once_with( + name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]} + ) + mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query) + + @patch("core.moderation.input_moderation.ModerationFactory") + @patch("core.moderation.input_moderation.TraceTask") + def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + trace_manager = MagicMock(spec=TraceQueueManager) + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + trace_manager=trace_manager, + ) + + trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value) + mock_trace_task.assert_called_once() + call_kwargs = mock_trace_task.call_args.kwargs + call_args = mock_trace_task.call_args.args + assert call_args[0] == TraceTaskName.MODERATION_TRACE + assert call_kwargs["message_id"] == message_id + assert call_kwargs["moderation_result"] == mock_result + assert call_kwargs["inputs"] == inputs + assert "timer" in call_kwargs + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content" + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + with pytest.raises(ModerationError) as excinfo: + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert str(excinfo.value) == "Blocked content" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, + action=ModerationAction.OVERRIDDEN, + inputs={"input_key": "overridden_value"}, + query="overridden query", + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is True + assert final_inputs == {"input_key": "overridden_value"} + assert final_query == "overridden query" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = MagicMock() + mock_result.flagged = True + mock_result.action = "NONE" # Some other action + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert flagged is True + assert final_inputs == inputs + assert final_query == query diff --git a/api/tests/unit_tests/core/moderation/test_output_moderation.py b/api/tests/unit_tests/core/moderation/test_output_moderation.py new file mode 100644 index 0000000000..c6a7cd3f61 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_output_moderation.py @@ -0,0 +1,234 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent +from core.moderation.base import ModerationAction, ModerationOutputsResult +from core.moderation.output_moderation import ModerationRule, OutputModeration + + +class TestOutputModeration: + @pytest.fixture + def mock_queue_manager(self): + return MagicMock(spec=AppQueueManager) + + @pytest.fixture + def moderation_rule(self): + return ModerationRule(type="keywords", config={"keywords": "badword"}) + + @pytest.fixture + def output_moderation(self, mock_queue_manager, moderation_rule): + return OutputModeration( + tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager + ) + + def test_should_direct_output(self, output_moderation): + assert output_moderation.should_direct_output() is False + output_moderation.final_output = "blocked" + assert output_moderation.should_direct_output() is True + + def test_get_final_output(self, output_moderation): + assert output_moderation.get_final_output() == "" + output_moderation.final_output = "blocked" + assert output_moderation.get_final_output() == "blocked" + + def test_append_new_token(self, output_moderation): + with patch.object(OutputModeration, "start_thread") as mock_start: + output_moderation.append_new_token("hello") + assert output_moderation.buffer == "hello" + mock_start.assert_called_once() + + output_moderation.thread = MagicMock() + output_moderation.append_new_token(" world") + assert output_moderation.buffer == "hello world" + assert mock_start.call_count == 1 + + def test_moderation_completion_no_flag(self, output_moderation): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output, flagged = output_moderation.moderation_completion("safe content") + + assert output == "safe content" + assert flagged is False + assert output_moderation.is_final_chunk is True + + def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "preset" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert isinstance(args[0], QueueMessageReplaceEvent) + assert args[0].text == "preset" + assert args[1] == PublishFrom.TASK_PIPELINE + + def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "masked content" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked content" + + def test_start_thread(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("core.moderation.output_moderation.current_app") as mock_current_app: + mock_current_app._get_current_object.return_value = mock_app + with patch("threading.Thread") as mock_thread_class: + mock_thread_instance = MagicMock() + mock_thread_class.return_value = mock_thread_instance + + thread = output_moderation.start_thread() + + assert thread == mock_thread_instance + mock_thread_class.assert_called_once() + mock_thread_instance.start.assert_called_once() + + def test_stop_thread(self, output_moderation): + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + output_moderation.thread = mock_thread + + output_moderation.stop_thread() + assert output_moderation.thread_running is False + + output_moderation.thread_running = True + mock_thread.is_alive.return_value = False + output_moderation.stop_thread() + assert output_moderation.thread_running is True + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_success(self, mock_factory_class, output_moderation): + mock_factory = mock_factory_class.return_value + mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_outputs.return_value = mock_result + + result = output_moderation.moderation("tenant", "app", "buffer") + + assert result == mock_result + mock_factory_class.assert_called_once_with( + name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"} + ) + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_exception(self, mock_factory_class, output_moderation): + mock_factory_class.side_effect = Exception("error") + + result = output_moderation.moderation("tenant", "app", "buffer") + assert result is None + + def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + # Test exit on thread_running=False + output_moderation.thread_running = False + output_moderation.worker(mock_app, 10) + # Should exit immediately + + def test_worker_no_flag(self, output_moderation): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output_moderation.buffer = "safe" + output_moderation.is_final_chunk = True + + # To avoid infinite loop, we'll set thread_running to False after one iteration + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + return mock_moderation.return_value + + mock_moderation.side_effect = side_effect + + output_moderation.worker(mock_app, 10) + + assert mock_moderation.called + + def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + assert output_moderation.final_output == "preset" + mock_queue_manager.publish.assert_called_once() + # It breaks on DIRECT_OUTPUT + + def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Use side_effect to change thread_running on second call + def side_effect(*args, **kwargs): + if mock_moderation.call_count > 1: + output_moderation.thread_running = False + return None + return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked") + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked" + + def test_worker_chunk_too_small(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("time.sleep") as mock_sleep: + # chunk_length < buffer_size and not is_final_chunk + output_moderation.buffer = "123" # length 3 + output_moderation.is_final_chunk = False + + def sleep_side_effect(seconds): + output_moderation.thread_running = False + + mock_sleep.side_effect = sleep_side_effect + + output_moderation.worker(mock_app, 10) # buffer_size 10 + + mock_sleep.assert_called_once_with(1) + + def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Return None (exception or no rule) + mock_moderation.return_value = None + + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "something" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_not_called() diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py index 65ee62b8dd..c7a4265a95 100644 --- a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -211,3 +211,16 @@ class TestCleanProcessor: text = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)" assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_remove_urls_emails_preserves_markdown_image_links(self): + """Remove plain URLs and emails while preserving markdown image links.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)" + result = CleanProcessor.clean(text, process_rule) + + assert result == "Email and remove but keep ![diagram](https://example.com/image.png)" + + def test_filter_string_returns_input_text(self): + """Test filter_string passthrough behavior.""" + processor = CleanProcessor() + assert processor.filter_string("raw text") == "raw text" diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py new file mode 100644 index 0000000000..538457ccc8 --- /dev/null +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -0,0 +1,249 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +def _doc(content: str) -> Document: + return Document(page_content=content) + + +class TestDataPostProcessor: + def test_init_sets_rerank_and_reorder_runners(self): + rerank_runner = object() + reorder_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock: + with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock: + processor = DataPostProcessor( + tenant_id="tenant-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_model={"config": "value"}, + weights={"weight": "value"}, + reorder_enabled=True, + ) + + assert processor.rerank_runner is rerank_runner + assert processor.reorder_runner is reorder_runner + rerank_mock.assert_called_once_with( + RerankMode.WEIGHTED_SCORE, + "tenant-1", + {"config": "value"}, + {"weight": "value"}, + ) + reorder_mock.assert_called_once_with(True) + + def test_invoke_applies_rerank_then_reorder(self): + original_documents = [_doc("doc-a")] + reranked_documents = [_doc("doc-b")] + reordered_documents = [_doc("doc-c")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = MagicMock() + processor.rerank_runner.run.return_value = reranked_documents + processor.reorder_runner = MagicMock() + processor.reorder_runner.run.return_value = reordered_documents + + result = processor.invoke( + query="how to test", + documents=original_documents, + score_threshold=0.3, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + processor.rerank_runner.run.assert_called_once_with( + "how to test", + original_documents, + 0.3, + 2, + "user-1", + QueryType.IMAGE_QUERY, + ) + processor.reorder_runner.run.assert_called_once_with(reranked_documents) + assert result == reordered_documents + + def test_invoke_returns_original_documents_when_no_runner_is_configured(self): + documents = [_doc("doc-a"), _doc("doc-b")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = None + processor.reorder_runner = None + + assert processor.invoke(query="query", documents=documents) == documents + + def test_get_rerank_runner_for_weighted_score(self): + weights_config = { + "vector_setting": { + "vector_weight": 0.7, + "embedding_provider_name": "provider-x", + "embedding_model_name": "embedding-y", + }, + "keyword_setting": {"keyword_weight": 0.3}, + } + expected_runner = object() + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.WEIGHTED_SCORE, + tenant_id="tenant-1", + reranking_model=None, + weights=weights_config, + ) + + assert result is expected_runner + kwargs = factory_mock.call_args.kwargs + assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE + assert kwargs["tenant_id"] == "tenant-1" + assert kwargs["weights"].vector_setting.vector_weight == 0.7 + assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x" + assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y" + assert kwargs["weights"].keyword_setting.keyword_weight == 0.3 + + def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + reranking_model = { + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + } + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock: + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner" + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model=reranking_model, + weights=None, + ) + + assert result is None + model_mock.assert_called_once_with("tenant-1", reranking_model) + factory_mock.assert_not_called() + + def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + expected_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance): + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + }, + weights=None, + ) + + assert result is expected_runner + factory_mock.assert_called_once_with( + runner_type=RerankMode.RERANKING_MODEL, + rerank_model_instance=model_instance, + ) + + def test_get_rerank_runner_returns_none_for_unsupported_mode(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None + assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None + + def test_get_reorder_runner_by_flag(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert isinstance(processor._get_reorder_runner(True), ReorderRunner) + assert processor._get_reorder_runner(False) is None + + def test_get_rerank_model_instance_returns_none_when_config_is_missing(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + assert processor._get_rerank_model_instance("tenant-1", None) is None + + def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + with pytest.raises(KeyError, match="reranking_model_name"): + processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) + + manager_instance.get_model_instance.assert_not_called() + + def test_get_rerank_model_instance_success(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + manager_instance.get_model_instance.return_value = model_instance + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is model_instance + manager_instance.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider-x", + model_type=ModelType.RERANK, + model="reranker-1", + ) + + def test_get_rerank_model_instance_handles_authorization_error(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: + manager_instance = manager_cls.return_value + manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is None + + +class TestReorderRunner: + def test_run_reorders_even_sized_document_list(self): + documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")] + + reordered = ReorderRunner().run(documents) + + assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"] + + def test_run_handles_odd_sized_and_empty_document_lists(self): + odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")] + runner = ReorderRunner() + + odd_reordered = runner.run(odd_documents) + + assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"] + assert runner.run([]) == [] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py new file mode 100644 index 0000000000..795a325a6b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -0,0 +1,414 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.keyword.jieba.jieba as jieba_module +from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default +from core.rag.models.document import Document + + +class _DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Field: + def __init__(self, name: str): + self._name = name + + def __eq__(self, other): + return ("eq", self._name, other) + + def in_(self, values): + return ("in", self._name, tuple(values)) + + +class _FakeQuery: + def __init__(self): + self.where_calls: list[tuple] = [] + + def where(self, *conditions): + self.where_calls.append(conditions) + return self + + +class _FakeExecuteResult: + def __init__(self, segments: list[SimpleNamespace]): + self._segments = segments + + def scalars(self): + return self + + def all(self): + return self._segments + + +class _FakeSelect: + def __init__(self): + self.where_conditions: tuple | None = None + + def where(self, *conditions): + self.where_conditions = conditions + return self + + +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): + return SimpleNamespace( + data_source_type=data_source_type, + keyword_table_dict=keyword_table_dict, + keyword_table="", + ) + + +def _dataset(dataset_keyword_table=None, keyword_number=None): + return SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + keyword_number=keyword_number, + dataset_keyword_table=dataset_keyword_table, + ) + + +@pytest.fixture +def patched_runtime(monkeypatch): + session = MagicMock() + db = SimpleNamespace(session=session) + storage = MagicMock() + lock = MagicMock(return_value=_DummyLock()) + redis_client = SimpleNamespace(lock=lock) + + monkeypatch.setattr(jieba_module, "db", db) + monkeypatch.setattr(jieba_module, "storage", storage) + monkeypatch.setattr(jieba_module, "redis_client", redis_client) + + return SimpleNamespace(session=session, storage=storage, lock=lock) + + +def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime): + dataset = _dataset(_dataset_keyword_table(), keyword_number=2) + keyword = Jieba(dataset) + handler = MagicMock() + handler.extract_keywords.return_value = {"kw1", "kw2"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + result = keyword.create( + [ + Document(page_content="alpha", metadata={"doc_id": "node-1"}), + SimpleNamespace(page_content="ignored", metadata=None), + ] + ) + + assert result is keyword + keyword._update_segment_keywords.assert_called_once() + call_args = keyword._update_segment_keywords.call_args.args + assert call_args[0] == "dataset-1" + assert call_args[1] == "node-1" + assert set(call_args[2]) == {"kw1", "kw2"} + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["kw1"] == {"node-1"} + assert saved_table["kw2"] == {"node-1"} + patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600) + + +def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + texts = [ + Document(page_content="extract-this", metadata={"doc_id": "node-1"}), + Document(page_content="use-manual", metadata={"doc_id": "node-2"}), + ] + keyword.add_texts(texts, keywords_list=[[], ["manual"]]) + + assert keyword._update_segment_keywords.call_count == 2 + first_call = keyword._update_segment_keywords.call_args_list[0].args + second_call = keyword._update_segment_keywords.call_args_list[1].args + assert set(first_call[2]) == {"auto"} + assert second_call[2] == ["manual"] + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1)) + handler = MagicMock() + handler.extract_keywords.return_value = {"from-extractor"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})]) + + handler.extract_keywords.assert_called_once_with("content", 1) + assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"} + + +def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + assert keyword.text_exists("node-1") is False + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + assert keyword.text_exists("node-2") is True + assert keyword.text_exists("node-x") is False + + +def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"]) + keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}}) + + +def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_not_called() + keyword._save_dataset_keyword_table.assert_called_once_with(None) + + +def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + document_id = _Field("document_id") + + keyword = Jieba(_dataset(_dataset_keyword_table())) + query_stmt = _FakeQuery() + patched_runtime.session.query.return_value = query_stmt + patched_runtime.session.execute.return_value = _FakeExecuteResult( + [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] + ) + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) + + documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) + + assert len(query_stmt.where_calls) == 2 + assert len(documents) == 1 + assert documents[0].page_content == "segment-content" + assert documents[0].metadata["doc_id"] == "node-2" + assert documents[0].metadata["doc_hash"] == "hash-2" + + +def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime): + db_keyword = _dataset_keyword_table(data_source_type="database") + file_keyword = _dataset_keyword_table(data_source_type="object_storage") + + keyword_db = Jieba(_dataset(db_keyword)) + keyword_db.delete() + patched_runtime.storage.delete.assert_not_called() + + keyword_file = Jieba(_dataset(file_keyword)) + keyword_file.delete() + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + assert patched_runtime.session.delete.call_count == 2 + assert patched_runtime.session.commit.call_count == 2 + + +def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="database") + keyword = Jieba(_dataset(dataset_keyword_table)) + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table + assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table + patched_runtime.session.commit.assert_called_once() + + +def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="file") + keyword = Jieba(_dataset(dataset_keyword_table)) + patched_runtime.storage.exists.return_value = True + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + patched_runtime.storage.save.assert_called_once() + save_args = patched_runtime.storage.save.call_args.args + assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt" + assert isinstance(save_args[1], bytes) + + +def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime): + existing = _dataset_keyword_table( + keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}} + ) + keyword = Jieba(_dataset(existing)) + assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]} + + missing_payload = _dataset_keyword_table(keyword_table_dict=None) + keyword_with_missing_payload = Jieba(_dataset(missing_payload)) + assert keyword_with_missing_payload._get_dataset_keyword_table() == {} + + +def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime): + created_tables: list[SimpleNamespace] = [] + + def _fake_dataset_keyword_table(**kwargs): + kwargs.setdefault("keyword_table", "") + kwargs.setdefault("keyword_table_dict", None) + table = SimpleNamespace(**kwargs) + created_tables.append(table) + return table + + keyword = Jieba(_dataset(dataset_keyword_table=None)) + monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table) + monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database") + + result = keyword._get_dataset_keyword_table() + + assert result == {} + assert len(created_tables) == 1 + assert created_tables[0].dataset_id == "dataset-1" + assert created_tables[0].data_source_type == "database" + assert '"index_id":"dataset-1"' in created_tables[0].keyword_table + patched_runtime.session.add.assert_called_once_with(created_tables[0]) + patched_runtime.session.commit.assert_called_once() + + +def test_add_and_delete_ids_from_keyword_table_helpers(): + keyword = Jieba(_dataset(_dataset_keyword_table())) + keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}} + + updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"]) + assert updated["kw1"] == {"node-1", "node-3"} + assert updated["kw3"] == {"node-3"} + + deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"]) + assert "kw3" not in deleted + assert "kw1" not in deleted + assert deleted["kw2"] == {"node-2"} + + +def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + handler = MagicMock() + handler.extract_keywords.return_value = ["kw-a", "kw-b"] + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + + ranked_ids = keyword._retrieve_ids_by_query( + {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}}, + "query", + k=1, + ) + + assert ranked_ids == ["node-2"] + + +def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + + keyword = Jieba(_dataset(_dataset_keyword_table())) + segment = SimpleNamespace(keywords=[]) + patched_runtime.session.scalar.return_value = segment + + keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"]) + + assert segment.keywords == ["kw1", "kw2"] + patched_runtime.session.add.assert_called_once_with(segment) + patched_runtime.session.commit.assert_called_once() + + patched_runtime.session.reset_mock() + patched_runtime.session.scalar.return_value = None + + keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"]) + + patched_runtime.session.add.assert_not_called() + patched_runtime.session.commit.assert_not_called() + + +def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.create_segment_keywords("node-1", ["kw"]) + keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"]) + keyword._save_dataset_keyword_table.assert_called_once() + + keyword._save_dataset_keyword_table.reset_mock() + keyword.update_segment_keywords_index("node-2", ["kw2"]) + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None) + second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None) + + keyword.multi_create_segment_keywords( + [ + {"segment": first_segment, "keywords": ["manual"]}, + {"segment": second_segment, "keywords": []}, + ] + ) + + assert first_segment.keywords == ["manual"] + assert second_segment.keywords == ["auto"] + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["manual"] == {"node-1"} + assert saved_table["auto"] == {"node-2"} + + +def test_set_orjson_default_and_dumps_with_sets(): + assert set(set_orjson_default({"a", "b"})) == {"a", "b"} + + with pytest.raises(TypeError, match="is not JSON serializable"): + set_orjson_default(("not", "a", "set")) + + payload = {"items": {"a", "b"}} + json_payload = dumps_with_sets(payload) + decoded = json.loads(json_payload) + assert set(decoded["items"]) == {"a", "b"} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py new file mode 100644 index 0000000000..a4586c141b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py @@ -0,0 +1,142 @@ +import sys +import types +from types import SimpleNamespace + +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +class _DummyTFIDF: + def __init__(self): + self.stop_words = set() + + @staticmethod + def extract_tags(sentence: str, top_k: int | None = 20, **kwargs): + return ["alpha_beta", "during", "gamma"] + + +def _install_fake_jieba_modules( + monkeypatch, + analyse_module: types.ModuleType, + jieba_attrs: dict[str, object] | None = None, + tfidf_module: types.ModuleType | None = None, +): + jieba_module = types.ModuleType("jieba") + jieba_module.__path__ = [] + if jieba_attrs: + for key, value in jieba_attrs.items(): + setattr(jieba_module, key, value) + + jieba_module.analyse = analyse_module + analyse_module.__package__ = "jieba" + + monkeypatch.setitem(sys.modules, "jieba", jieba_module) + monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module) + if tfidf_module is not None: + monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module) + else: + monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False) + + +def test_init_uses_existing_default_tfidf(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + default_tfidf = _DummyTFIDF() + analyse_module.default_tfidf = default_tfidf + + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert handler._tfidf is default_tfidf + assert handler._tfidf.stop_words == STOPWORDS + + +def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + class _TFIDFFactory(_DummyTFIDF): + pass + + analyse_module.TFIDF = _TFIDFFactory + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _TFIDFFactory) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + tfidf_module = types.ModuleType("jieba.analyse.tfidf") + + class _ImportedTFIDF(_DummyTFIDF): + pass + + tfidf_module.TFIDF = _ImportedTFIDF + _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _ImportedTFIDF) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1) + + assert fallback_keywords == ["two"] + + +def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]}) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["x"] + + +def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules( + monkeypatch, + analyse_module, + jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])}, + ) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["foo"] + + +def test_extract_keywords_expands_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"]) + + keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3) + + assert "alpha-beta" in keywords + assert "alpha" in keywords + assert "beta" in keywords + assert "during" in keywords + assert "gamma" in keywords + + +def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + + expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"}) + + assert "alpha-during-beta" in expanded + assert "alpha" in expanded + assert "beta" in expanded + assert "during" not in expanded diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py new file mode 100644 index 0000000000..1b1541ddd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py @@ -0,0 +1,6 @@ +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +def test_stopwords_loaded(): + assert "during" in STOPWORDS + assert "the" in STOPWORDS diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py new file mode 100644 index 0000000000..55e22aea0a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document + + +class _KeywordThatRaises(BaseKeyword): + def create(self, texts: list[Document], **kwargs): + return super().create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + return super().add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return super().text_exists(id) + + def delete_by_ids(self, ids: list[str]): + return super().delete_by_ids(ids) + + def delete(self): + return super().delete() + + def search(self, query: str, **kwargs): + return super().search(query, **kwargs) + + +class _KeywordForHelpers(BaseKeyword): + def __init__(self, dataset, existing_ids: set[str] | None = None): + super().__init__(dataset) + self._existing_ids = existing_ids or set() + + def create(self, texts: list[Document], **kwargs): + return self + + def add_texts(self, texts: list[Document], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete(self): + return None + + def search(self, query: str, **kwargs): + return [] + + +def test_abstract_methods_raise_not_implemented(): + keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1")) + + with pytest.raises(NotImplementedError): + keyword.create([]) + + with pytest.raises(NotImplementedError): + keyword.add_texts([]) + + with pytest.raises(NotImplementedError): + keyword.text_exists("doc-1") + + with pytest.raises(NotImplementedError): + keyword.delete_by_ids(["doc-1"]) + + with pytest.raises(NotImplementedError): + keyword.delete() + + with pytest.raises(NotImplementedError): + keyword.search("query") + + +def test_filter_duplicate_texts_removes_existing_doc_ids(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"}) + texts = [ + Document(page_content="keep", metadata={"doc_id": "keep"}), + Document(page_content="duplicate", metadata={"doc_id": "duplicate"}), + SimpleNamespace(page_content="without-metadata", metadata=None), + ] + + filtered = keyword._filter_duplicate_texts(texts) + + assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"] + assert any(text.metadata is None for text in filtered) + + +def test_get_uuids_returns_only_docs_with_metadata(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1")) + texts = [ + Document(page_content="doc-1", metadata={"doc_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "doc-2"}), + SimpleNamespace(page_content="doc-3", metadata=None), + ] + + assert keyword._get_uuids(texts) == ["doc-1", "doc-2"] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py new file mode 100644 index 0000000000..0d969a3270 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py @@ -0,0 +1,84 @@ +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.keyword.keyword_type import KeyWordType +from core.rag.models.document import Document + + +def test_get_keyword_factory_returns_jieba_factory(monkeypatch): + fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba") + + class FakeJieba: + pass + + fake_module.Jieba = FakeJieba + monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module) + + assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba + + +def test_get_keyword_factory_raises_for_unsupported_type(): + with pytest.raises(ValueError, match="Keyword store unsupported is not supported"): + Keyword.get_keyword_factory("unsupported") + + +def test_keyword_initialization_uses_configured_factory(monkeypatch): + dataset = SimpleNamespace(id="dataset-1") + fake_processor = MagicMock() + + monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA) + monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor)) + + keyword = Keyword(dataset) + + assert keyword._keyword_processor is fake_processor + + +def test_keyword_methods_forward_to_processor(): + processor = MagicMock() + processor.text_exists.return_value = True + processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})] + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = processor + + docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})] + keyword.create(docs, foo="bar") + keyword.add_texts(docs, batch=True) + assert keyword.text_exists("doc-1") is True + keyword.delete_by_ids(["doc-1"]) + keyword.delete() + assert keyword.search("query", top_k=1) == processor.search.return_value + + processor.create.assert_called_once_with(docs, foo="bar") + processor.add_texts.assert_called_once_with(docs, batch=True) + processor.text_exists.assert_called_once_with("doc-1") + processor.delete_by_ids.assert_called_once_with(["doc-1"]) + processor.delete.assert_called_once() + processor.search.assert_called_once_with("query", top_k=1) + + +def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes(): + class Processor: + value = 1 + + @staticmethod + def custom(): + return "ok" + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = Processor() + + assert keyword.custom() == "ok" + + with pytest.raises(AttributeError): + _ = keyword.value + + keyword._keyword_processor = None + with pytest.raises(AttributeError): + _ = keyword.missing_method diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py new file mode 100644 index 0000000000..8c1e4e478b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -0,0 +1,1174 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, call, patch +from uuid import uuid4 + +import pytest + +from core.rag.datasource import retrieval_service as retrieval_service_module +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset + + +def create_mock_document( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + This helper function standardizes document creation across tests, + ensuring consistent structure and reducing code duplication. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + + Example: + >>> doc = create_mock_document("Python is great", "doc1", score=0.95) + >>> assert doc.metadata["score"] == 0.95 + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + # Merge additional metadata if provided + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +class _ImmediateFuture: + def __init__(self, exception: Exception | None = None) -> None: + self._exception = exception + self.cancel_called = False + + def exception(self) -> Exception | None: + return self._exception + + def cancel(self) -> None: + self.cancel_called = True + + +class _ImmediateExecutor: + def __init__(self) -> None: + self.futures: list[_ImmediateFuture] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + try: + fn(*args, **kwargs) + future = _ImmediateFuture() + except Exception as exc: # pragma: no cover - only for defensive parity with Future semantics + future = _ImmediateFuture(exc) + self.futures.append(future) + return future + + +class _FakeExecuteScalarResult: + def __init__(self, data: list) -> None: + self._data = data + + def all(self) -> list: + return self._data + + +class _FakeExecuteResult: + def __init__(self, data: list) -> None: + self._data = data + + def scalars(self) -> _FakeExecuteScalarResult: + return _FakeExecuteScalarResult(self._data) + + +class _FakeSummaryQuery: + def __init__(self, summaries: list) -> None: + self._summaries = summaries + + def filter(self, *args, **kwargs): + return self + + def all(self) -> list: + return self._summaries + + +class _FakeSession: + def __init__(self, execute_payloads: list[list], summaries: list) -> None: + self._payloads = list(execute_payloads) + self._summaries = summaries + + def execute(self, stmt): + data = self._payloads.pop(0) if self._payloads else [] + return _FakeExecuteResult(data) + + def query(self, model): + return _FakeSummaryQuery(self._summaries) + + +class _FakeSessionContext: + def __init__(self, session: _FakeSession) -> None: + self._session = session + + def __enter__(self) -> _FakeSession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class _SimpleRetrievalChildChunk: + def __init__(self, id: str, content: str, score: float, position: int) -> None: + self.id = id + self.content = content + self.score = score + self.position = position + + +class _SimpleRetrievalSegment: + def __init__( + self, + segment, + child_chunks: list[_SimpleRetrievalChildChunk] | None = None, + score: float | None = None, + files: list[dict[str, str | int]] | None = None, + summary: str | None = None, + ) -> None: + self.segment = segment + self.child_chunks = child_chunks + self.score = score + self.files = files + self.summary = summary + + +class TestRetrievalServiceInternals: + @pytest.fixture + def internal_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = "dataset-id" + dataset.tenant_id = "tenant-id" + dataset.is_multimodal = False + dataset.doc_form = IndexStructureType.PARENT_CHILD_INDEX + return dataset + + @pytest.fixture + def internal_flask_app(self): + app = MagicMock() + app.app_context.return_value.__enter__ = Mock() + app.app_context.return_value.__exit__.return_value = False + return app + + def test_retrieve_with_attachment_ids_only(self, monkeypatch, internal_dataset): + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset", return_value=internal_dataset), + patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") as mock_retrieve, + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + def side_effect( + flask_app, + retrieval_method, + dataset, + all_documents, + exceptions, + query=None, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + ): + all_documents.append(create_mock_document(f"content-{attachment_id}", attachment_id or "none", 0.9)) + + mock_retrieve.side_effect = side_effect + + results = RetrievalService.retrieve( + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset_id=internal_dataset.id, + query="", + attachment_ids=["att-1", "att-2"], + ) + + assert len(results) == 2 + assert {doc.metadata["doc_id"] for doc in results} == {"att-1", "att-2"} + assert mock_retrieve.call_count == 2 + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + mock_validate.return_value = "validated-condition" + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, + ) + + assert results == expected_documents + mock_validate.assert_called_once() + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition="validated-condition", + ) + + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_returns_empty_when_dataset_not_found(self, mock_scalar): + mock_scalar.return_value = None + + results = RetrievalService.external_retrieve(dataset_id="missing", query="q") + + assert results == [] + + @patch("core.rag.datasource.retrieval_service.Session") + def test_get_dataset_queries_by_id(self, mock_session_class): + expected_dataset = Mock(spec=Dataset) + mock_session = Mock() + mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session_class.return_value.__enter__.return_value = mock_session + + with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): + result = RetrievalService._get_dataset("dataset-123") + + assert result == expected_dataset + mock_session.query.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_success(self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.return_value = [create_mock_document("keyword-content", "kw-1", 0.91)] + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "with quotes"', + top_k=5, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + keyword_instance.search.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_dataset_missing(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.side_effect = RuntimeError("keyword failed") + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["keyword failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [create_mock_document("vector-content", "vec-1", 0.7)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + document_ids_filter=["doc-1"], + query_type=QueryType.TEXT_QUERY, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_vector.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_non_multimodal_returns_early( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-1", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == [] + assert exceptions == [] + vector_instance.search_by_file.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_with_vision_reranking( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + reranked_docs = [create_mock_document("image-content-reranked", "img-doc", 0.97)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = True + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + model_manager.check_model_support_vision.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_without_vision_support( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = [create_mock_document("unused", "unused", 0.1)] + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = False + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == original_docs + assert exceptions == [] + processor_instance.invoke.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_with_reranking_non_multimodal( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("vector-content", "vec-doc", 0.62)] + reranked_docs = [create_mock_document("vector-content-reranked", "vec-doc", 0.89)] + + vector_instance = Mock() + vector_instance.search_by_vector.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_appends_exception_when_vector_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.side_effect = RuntimeError("vector failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == [] + assert exceptions == ["vector failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = [create_mock_document("fulltext", "ft-1", 0.68)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "x"', + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_full_text.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_with_reranking( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("fulltext", "ft-1", 0.68)] + reranked_docs = [create_mock_document("fulltext-reranked", "ft-1", 0.9)] + + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_dataset_not_found(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.side_effect = RuntimeError("fulltext failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["fulltext failed"] + + def test_format_retrieval_documents_with_empty_input_returns_empty_list(self): + assert RetrievalService.format_retrieval_documents([]) == [] + + def test_format_retrieval_documents_without_document_id_returns_empty_list(self): + documents = [Document(page_content="content", metadata={"doc_id": "doc-1", "score": 0.4}, provider="dify")] + + assert RetrievalService.format_retrieval_documents(documents) == [] + + def test_format_retrieval_documents_with_parent_child_summary_and_attachments(self, monkeypatch): + dataset_doc_parent = SimpleNamespace( + id="doc-parent", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + dataset_doc_text = SimpleNamespace(id="doc-text", doc_form="paragraph", dataset_id="dataset-id") + dataset_doc_parent_summary = SimpleNamespace( + id="doc-parent-summary", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + + dataset_query = Mock() + dataset_query.where.return_value.options.return_value.all.return_value = [ + dataset_doc_parent, + dataset_doc_text, + dataset_doc_parent_summary, + ] + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) + monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) + monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) + + input_documents = [ + Document( + page_content="child node content", + metadata={"document_id": "doc-parent", "doc_id": "child-node-1", "score": 0.7}, + provider="dify", + ), + Document( + page_content="parent image", + metadata={ + "document_id": "doc-parent", + "doc_id": "attach-node-1", + "doc_type": DocType.IMAGE, + "score": 0.8, + }, + provider="dify", + ), + Document( + page_content="text index node", + metadata={"document_id": "doc-text", "doc_id": "index-node-1", "score": 0.6}, + provider="dify", + ), + Document( + page_content="text image node", + metadata={ + "document_id": "doc-text", + "doc_id": "attach-text-1", + "doc_type": DocType.IMAGE, + "score": 0.65, + }, + provider="dify", + ), + Document( + page_content="summary candidate 1", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-1", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.9", + }, + provider="dify", + ), + Document( + page_content="summary candidate 2", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-2", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.95", + }, + provider="dify", + ), + Document( + page_content="invalid score summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-invalid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "invalid", + }, + provider="dify", + ), + Document( + page_content="valid parent summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-valid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "0.4", + }, + provider="dify", + ), + ] + + child_chunk = SimpleNamespace( + id="child-chunk-1", + segment_id="segment-parent", + index_node_id="child-node-1", + content="child details", + position=2, + ) + segment_parent = SimpleNamespace(id="segment-parent", document_id="doc-parent", index_node_id="parent-node") + segment_text = SimpleNamespace(id="segment-text", document_id="doc-text", index_node_id="index-node-1") + segment_summary = SimpleNamespace(id="segment-summary", document_id="doc-text", index_node_id="summary-node") + segment_parent_summary = SimpleNamespace( + id="segment-parent-summary", + document_id="doc-parent-summary", + index_node_id="summary-parent-node", + ) + + fake_session = _FakeSession( + execute_payloads=[ + [child_chunk], + [segment_text], + [segment_parent, segment_text], + [segment_summary, segment_parent_summary], + ], + summaries=[ + SimpleNamespace(chunk_id="segment-summary", summary_content="summary for text"), + SimpleNamespace(chunk_id="segment-parent-summary", summary_content="summary for parent"), + ], + ) + monkeypatch.setattr( + retrieval_service_module.session_factory, + "create_session", + lambda: _FakeSessionContext(fake_session), + ) + monkeypatch.setattr( + RetrievalService, + "get_segment_attachment_infos", + lambda attachment_ids, session: [ + { + "attachment_id": "attach-node-1", + "attachment_info": { + "id": "attach-node-1", + "name": "img-parent", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://parent", + "size": 11, + }, + "segment_id": "segment-parent", + }, + { + "attachment_id": "attach-text-1", + "attachment_info": { + "id": "attach-text-1", + "name": "img-text", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://text", + "size": 22, + }, + "segment_id": "segment-text", + }, + ], + ) + + result = RetrievalService.format_retrieval_documents(input_documents) + + assert len(result) == 4 + result_by_segment_id = {item.segment.id: item for item in result} + assert result_by_segment_id["segment-summary"].score == pytest.approx(0.95) + assert result_by_segment_id["segment-summary"].summary == "summary for text" + assert result_by_segment_id["segment-parent"].score == pytest.approx(0.8) + assert result_by_segment_id["segment-parent"].files is not None + assert len(result_by_segment_id["segment-parent"].child_chunks or []) == 1 + assert result_by_segment_id["segment-text"].score == pytest.approx(0.65) + assert result_by_segment_id["segment-parent-summary"].score == pytest.approx(0.4) + assert result_by_segment_id["segment-parent-summary"].summary == "summary for parent" + assert result_by_segment_id["segment-parent-summary"].child_chunks == [] + + def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): + rollback = Mock() + monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) + + documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] + + with pytest.raises(RuntimeError, match="db error"): + RetrievalService.format_retrieval_documents(documents) + + rollback.assert_called_once() + + def test_retrieve_internal_returns_early_without_query_or_attachment(self, internal_dataset, internal_flask_app): + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=exceptions, + query=None, + attachment_id=None, + ) + + assert all_documents == [] + assert exceptions == [] + + def test_retrieve_internal_cancels_futures_when_future_has_exception(self, internal_dataset, internal_flask_app): + future_error = Mock() + future_error.exception.return_value = RuntimeError("future failed") + future_ok = Mock() + future_ok.exception.return_value = None + + with ( + patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor, + patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + return_value=[future_error, future_ok], + ), + ): + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = [future_error, future_ok] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=[], + query="query", + attachment_id="file-1", + ) + + future_error.cancel.assert_called() + future_ok.cancel.assert_called() + + def test_retrieve_internal_raises_value_error_when_exceptions_exist( + self, monkeypatch, internal_dataset, internal_flask_app + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + with patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") as mock_keyword_search: + mock_keyword_search.side_effect = lambda *args, **kwargs: None + with pytest.raises(ValueError, match="keyword error"): + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.KEYWORD_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=["keyword error"], + query="query", + ) + + def test_retrieve_internal_hybrid_weighted_attachment_flow(self, monkeypatch, internal_dataset, internal_flask_app): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + text_doc = create_mock_document("text", "text-doc", 0.81) + image_doc = create_mock_document("image", "image-doc", 0.72) + fulltext_doc = create_mock_document("full", "full-doc", 0.65) + processed_doc = create_mock_document("processed", "processed-doc", 0.99) + + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") as mock_embedding_search, + patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search") as mock_fulltext, + patch("core.rag.datasource.retrieval_service.DataPostProcessor") as mock_processor_class, + ): + + def embedding_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + query_type=QueryType.TEXT_QUERY, + ): + if query_type == QueryType.IMAGE_QUERY: + all_documents.append(image_doc) + else: + all_documents.append(text_doc) + + mock_embedding_search.side_effect = embedding_side_effect + + def fulltext_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + ): + all_documents.append(fulltext_doc) + + mock_fulltext.side_effect = fulltext_side_effect + processor_instance = Mock() + processor_instance.invoke.return_value = [processed_doc] + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.HYBRID_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=[], + query="query", + attachment_id="file-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + top_k=3, + ) + + assert len(all_documents) == 4 + assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_info_success(self, mock_sign): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = binding + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result == { + "attachment_info": { + "id": "upload-1", + "name": "file-name", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + mock_sign.assert_called_once_with("upload-1", "png") + + def test_get_segment_attachment_info_returns_none_when_binding_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = None + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): + upload_query = Mock() + upload_query.where.return_value.first.return_value = None + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): + upload_query = Mock() + upload_query.where.return_value.all.return_value = [] + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + def test_get_segment_attachment_infos_returns_empty_when_bindings_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_infos_success(self, mock_sign): + upload_file_1 = SimpleNamespace( + id="upload-1", + name="file-1", + extension="png", + mime_type="image/png", + size=42, + ) + upload_file_2 = SimpleNamespace( + id="upload-2", + name="file-2", + extension="jpg", + mime_type="image/jpeg", + size=99, + ) + binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") + + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [binding] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) + + assert result == [ + { + "attachment_id": "upload-1", + "attachment_info": { + "id": "upload-1", + "name": "file-1", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + ] + mock_sign.assert_has_calls( + [ + call("upload-1", "png"), + call("upload-2", "jpg"), + ] + ) + assert mock_sign.call_count == 2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py new file mode 100644 index 0000000000..e063a49f22 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py @@ -0,0 +1,74 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory + + +def test_validate_distance_function_accepts_supported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + assert factory._validate_distance_function("cosine") == "cosine" + assert factory._validate_distance_function("euclidean") == "euclidean" + + +def test_validate_distance_function_rejects_unsupported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + with pytest.raises(ValueError, match="Invalid distance function"): + factory._validate_distance_function("dot_product") + + +def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" + + +def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", + index_struct_dict=None, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + vector_cls.assert_called_once() + assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2" + assert dataset.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py new file mode 100644 index 0000000000..545565cdf4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from core.rag.models.document import Document + + +def test_init_prefers_openapi_when_api_config_is_provided(): + api_config = AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls: + vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None) + + assert vector.analyticdb_vector == "openapi_runner" + openapi_cls.assert_called_once_with("COLLECTION", api_config) + + +def test_init_uses_sql_implementation_when_api_config_is_missing(): + sql_config = AnalyticdbVectorBySqlConfig( + host="localhost", + port=5432, + account="account", + account_password="password", + min_connection=1, + max_connection=2, + namespace="dify", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls: + vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config) + + assert vector.analyticdb_vector == "sql_runner" + sql_cls.assert_called_once_with("COLLECTION", sql_config) + + +def test_init_raises_when_both_configs_are_missing(): + with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"): + AnalyticdbVector("COLLECTION", api_config=None, sql_config=None) + + +def test_vector_methods_delegate_to_underlying_implementation(): + runner = MagicMock() + runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})] + runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})] + runner.text_exists.return_value = True + + vector = AnalyticdbVector.__new__(AnalyticdbVector) + vector.analyticdb_vector = runner + + texts = [Document(page_content="hello", metadata={"doc_id": "d1"})] + vector.create(texts=texts, embeddings=[[0.1, 0.2]]) + vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]]) + assert vector.text_exists("d1") is True + vector.delete_by_ids(["d1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value + assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value + vector.delete() + + runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) + runner.delete_by_ids.assert_called_once_with(["d1"]) + runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") + runner.delete.assert_called_once() + + +def test_get_type_is_analyticdb(): + vector = AnalyticdbVector.__new__(AnalyticdbVector) + assert vector.get_type() == "analyticdb" + + +def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "auto_collection" + assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig) + assert args[2] is None + assert dataset.index_struct is not None + + +def test_factory_builds_sql_config_when_host_is_present(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "existing" + assert args[1] is None + assert isinstance(args[2], AnalyticdbVectorBySqlConfig) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py new file mode 100644 index 0000000000..45777774d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -0,0 +1,384 @@ +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.models.document import Document + + +def _request_class(name: str): + class _Request: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + _Request.__name__ = name + return _Request + + +def _install_openapi_stubs(monkeypatch): + gpdb_package = types.ModuleType("alibabacloud_gpdb20160503") + gpdb_package.__path__ = [] + gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models") + for class_name in [ + "InitVectorDatabaseRequest", + "DescribeNamespaceRequest", + "CreateNamespaceRequest", + "DescribeCollectionRequest", + "CreateCollectionRequest", + "UpsertCollectionDataRequestRows", + "UpsertCollectionDataRequest", + "QueryCollectionDataRequest", + "DeleteCollectionDataRequest", + "DeleteCollectionRequest", + ]: + setattr(gpdb_models, class_name, _request_class(class_name)) + + class _Client: + def __init__(self, config): + self.config = config + + gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client") + gpdb_client.Client = _Client + gpdb_package.models = gpdb_models + + tea_openapi = types.ModuleType("alibabacloud_tea_openapi") + tea_openapi.__path__ = [] + tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models") + + class OpenApiConfig: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + tea_openapi_models.Config = OpenApiConfig + tea_openapi.models = tea_openapi_models + + tea_package = types.ModuleType("Tea") + tea_package.__path__ = [] + tea_exceptions = types.ModuleType("Tea.exceptions") + + class TeaError(Exception): + def __init__(self, status_code=None, **kwargs): + super().__init__("TeaException") + status_code = kwargs.get("statusCode", status_code) + self.statusCode = status_code + self.status_code = status_code + + tea_exceptions.TeaException = TeaError + tea_package.exceptions = tea_exceptions + + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models) + monkeypatch.setitem(sys.modules, "Tea", tea_package) + monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions) + + return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig) + + +def _config() -> AnalyticdbVectorOpenAPIConfig: + return AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("access_key_id", "", "ANALYTICDB_KEY_ID"), + ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"), + ("region_id", "", "ANALYTICDB_REGION_ID"), + ("instance_id", "", "ANALYTICDB_INSTANCE_ID"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"), + ], +) +def test_openapi_config_validation(field, value, error_message): + values = _config().model_dump() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorOpenAPIConfig.model_validate(values) + + +def test_openapi_config_to_client_params(): + config = _config() + params = config.to_analyticdb_client_params() + + assert params["access_key_id"] == "ak" + assert params["access_key_secret"] == "sk" + assert params["region_id"] == "cn-hangzhou" + assert params["read_timeout"] == 60000 + + +def test_init_creates_openapi_client_and_runs_initialize(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + initialize_mock = MagicMock() + monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock) + + vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config()) + + assert vector._collection_name == "collection_1" + assert isinstance(vector._client_config, stubs.OpenApiConfig) + assert vector._client_config.user_agent == "dify" + assert vector._client_config.access_key_id == "ak" + assert vector._client.config is vector._client_config + initialize_mock.assert_called_once_with() + + +def test_initialize_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + vector._create_namespace_if_not_exists.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + vector._create_namespace_if_not_exists.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_initialize_vector_database_calls_openapi_client(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._initialize_vector_database() + + request = vector._client.init_vector_database.call_args.args[0] + assert request.dbinstance_id == "instance-1" + assert request.region_id == "cn-hangzhou" + assert request.manager_account == "account" + assert request.manager_account_password == "password" + + +def test_create_namespace_creates_when_namespace_not_found(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404) + + vector._create_namespace_if_not_exists() + + vector._client.create_namespace.assert_called_once() + + +def test_create_namespace_raises_on_unexpected_api_error(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create namespace"): + vector._create_namespace_if_not_exists() + + +def test_create_namespace_noop_when_namespace_exists(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._create_namespace_if_not_exists() + + vector._client.describe_namespace.assert_called_once() + vector._client.create_namespace.assert_not_called() + + +def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.create_collection.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.describe_collection.assert_not_called() + vector._client.create_collection.assert_not_called() + + +def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create collection collection_1"): + vector._create_collection_if_not_exists(embedding_dimension=512) + + +def test_openapi_add_delete_and_search_methods(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + documents = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + embeddings = [[0.1, 0.2], [0.2, 0.3]] + vector.add_texts(documents, embeddings) + + upsert_request = vector._client.upsert_collection_data.call_args.args[0] + assert upsert_request.collection == "collection_1" + assert len(upsert_request.rows) == 1 + + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()])) + ) + assert vector.text_exists("d1") is True + + vector.delete_by_ids(["d1", "d2"]) + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')" + + vector.delete_by_metadata_field("document_id", "doc-1") + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'" + + match_high = SimpleNamespace( + score=0.9, + metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"}, + values=SimpleNamespace(value=[1.0, 2.0]), + ) + match_low = SimpleNamespace( + score=0.1, + metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"}, + values=SimpleNamespace(value=[3.0, 4.0]), + ) + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high])) + ) + + docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + assert len(docs_by_vector) == 1 + assert docs_by_vector[0].page_content == "high" + assert docs_by_vector[0].metadata["score"] == 0.9 + + docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2) + assert len(docs_by_text) == 1 + assert docs_by_text[0].page_content == "high" + + +def test_text_exists_returns_false_when_matches_empty(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[])) + ) + + assert vector.text_exists("missing-id") is False + + +def test_openapi_delete_success(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector.delete() + vector._client.delete_collection.assert_called_once() + + +def test_openapi_delete_propagates_errors(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.delete_collection.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.delete() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py new file mode 100644 index 0000000000..8f1206696b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -0,0 +1,427 @@ +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import psycopg2.errors +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( + AnalyticdbVectorBySql, + AnalyticdbVectorBySqlConfig, +) +from core.rag.models.document import Document + + +def _config_values() -> dict: + return { + "host": "localhost", + "port": 5432, + "account": "account", + "account_password": "password", + "min_connection": 1, + "max_connection": 2, + "namespace": "dify", + } + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("host", "", "ANALYTICDB_HOST"), + ("port", 0, "ANALYTICDB_PORT"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"), + ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"), + ], +) +def test_sql_config_required_fields(field, value, error_message): + values = _config_values() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_sql_config_rejects_min_connection_greater_than_max_connection(): + values = _config_values() + values["min_connection"] = 10 + values["max_connection"] = 2 + + with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_initialize_skips_when_cache_exists(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + sql_module.redis_client.set.assert_called_once() + + +def test_create_connection_pool_uses_psycopg2_pool(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + pool_instance = MagicMock() + monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance)) + + pool = vector._create_connection_pool() + + assert pool is pool_instance + sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once() + + +def test_get_cursor_context_manager_handles_connection_lifecycle(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + cursor = MagicMock() + connection = MagicMock() + connection.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = connection + vector.pool = pool + + with vector._get_cursor() as cur: + assert cur is cursor + + cursor.close.assert_called_once() + connection.commit.assert_called_once() + pool.putconn.assert_called_once_with(connection) + + +def test_add_texts_inserts_only_documents_with_metadata(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id") + monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock()) + + docs = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]]) + + execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args + assert execute_args[0] is cursor + assert len(execute_args[2]) == 1 + + +def test_text_exists_returns_true_and_false_based_on_query_result(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.fetchone.return_value = ("row",) + assert vector.text_exists("d1") is True + + cursor.fetchone.return_value = None + assert vector.text_exists("d1") is False + + +def test_delete_by_ids_handles_empty_input_and_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_ids(["d1"]) + + +def test_delete_by_metadata_field_handles_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_metadata_field("document_id", "doc-1") + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_vector_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k) + + +def test_search_by_vector_returns_documents_above_threshold(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}), + ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content 1" + assert docs[0].metadata["score"] == 0.8 + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_full_text_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=invalid_top_k) + + +def test_search_by_full_text_returns_documents(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + assert docs[0].page_content == "content 1" + + +def test_delete_drops_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete() + + cursor.execute.assert_called_once() + + +def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch): + config = AnalyticdbVectorBySqlConfig(**_config_values()) + created_pool = MagicMock() + + monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock()) + monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool)) + + vector = AnalyticdbVectorBySql("My_Collection", config) + + assert vector._collection_name == "my_collection" + assert vector.table_name == "dify.my_collection" + assert vector.databaseName == "knowledgebase" + assert vector.pool is created_pool + + +def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + bootstrap_cursor.execute.side_effect = RuntimeError("database already exists") + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + + def _execute(sql, *args, **kwargs): + if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql: + raise RuntimeError("already exists") + + worker_cursor.execute.side_effect = _execute + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + vector._initialize_vector_database() + + bootstrap_cursor.close.assert_called_once() + bootstrap_connection.close.assert_called_once() + vector._create_connection_pool.assert_called_once() + assert any( + "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0] + for call in worker_cursor.execute.call_args_list + ) + assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list) + + +def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable") + + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + with pytest.raises(RuntimeError, match="Failed to create zhparser extension"): + vector._initialize_vector_database() + + worker_connection.rollback.assert_called_once() + + +def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + vector._create_collection_if_not_exists(embedding_dimension=3) + + assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) + assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) + sql_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + cursor.execute.side_effect = RuntimeError("permission denied") + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + with pytest.raises(RuntimeError, match="permission denied"): + vector._create_collection_if_not_exists(embedding_dimension=3) + + +def test_delete_methods_raise_when_error_is_not_missing_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.execute.side_effect = RuntimeError("unexpected delete failure") + with pytest.raises(RuntimeError, match="unexpected delete failure"): + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("unexpected metadata failure") + with pytest.raises(RuntimeError, match="unexpected metadata failure"): + vector.delete_by_metadata_field("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py new file mode 100644 index 0000000000..c46c3d5e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -0,0 +1,542 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_pymochow_modules(): + pymochow = types.ModuleType("pymochow") + pymochow.__path__ = [] + pymochow_auth = types.ModuleType("pymochow.auth") + pymochow_auth.__path__ = [] + pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials") + pymochow_configuration = types.ModuleType("pymochow.configuration") + pymochow_exception = types.ModuleType("pymochow.exception") + pymochow_model = types.ModuleType("pymochow.model") + pymochow_model.__path__ = [] + pymochow_model_database = types.ModuleType("pymochow.model.database") + pymochow_model_enum = types.ModuleType("pymochow.model.enum") + pymochow_model_schema = types.ModuleType("pymochow.model.schema") + pymochow_model_table = types.ModuleType("pymochow.model.table") + + class _SimpleObject: + def __init__(self, *args, **kwargs): + self.args = args + for key, value in kwargs.items(): + setattr(self, key, value) + + class ServerError(Exception): + def __init__(self, code): + super().__init__(f"server error {code}") + self.code = code + + class ServerErrCode: + TABLE_NOT_EXIST = 1001 + DB_ALREADY_EXIST = 1002 + + class IndexType: + __members__ = {"HNSW": "HNSW"} + + class MetricType: + __members__ = {"IP": "IP"} + + class IndexState: + NORMAL = "NORMAL" + + class TableState: + NORMAL = "NORMAL" + + class InvertedIndexAnalyzer: + DEFAULT_ANALYZER = "DEFAULT_ANALYZER" + + class InvertedIndexParseMode: + COARSE_MODE = "COARSE_MODE" + + class InvertedIndexFieldAttribute: + ANALYZED = "ANALYZED" + + class FieldType: + STRING = "STRING" + TEXT = "TEXT" + JSON = "JSON" + FLOAT_VECTOR = "FLOAT_VECTOR" + + pymochow.MochowClient = _SimpleObject + pymochow_credentials.BceCredentials = _SimpleObject + pymochow_configuration.Configuration = _SimpleObject + pymochow_exception.ServerError = ServerError + pymochow_model_database.Database = _SimpleObject + + pymochow_model_enum.FieldType = FieldType + pymochow_model_enum.IndexState = IndexState + pymochow_model_enum.IndexType = IndexType + pymochow_model_enum.MetricType = MetricType + pymochow_model_enum.ServerErrCode = ServerErrCode + pymochow_model_enum.TableState = TableState + + for cls_name in [ + "AutoBuildRowCountIncrement", + "Field", + "FilteringIndex", + "HNSWParams", + "InvertedIndex", + "InvertedIndexParams", + "Schema", + "VectorIndex", + ]: + setattr(pymochow_model_schema, cls_name, _SimpleObject) + pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer + pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute + pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode + + for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]: + setattr(pymochow_model_table, cls_name, _SimpleObject) + + pymochow.auth = pymochow_auth + pymochow.model = pymochow_model + pymochow_auth.bce_credentials = pymochow_credentials + pymochow_model.database = pymochow_model_database + pymochow_model.enum = pymochow_model_enum + pymochow_model.schema = pymochow_model_schema + pymochow_model.table = pymochow_model_table + + modules = { + "pymochow": pymochow, + "pymochow.auth": pymochow_auth, + "pymochow.auth.bce_credentials": pymochow_credentials, + "pymochow.configuration": pymochow_configuration, + "pymochow.exception": pymochow_exception, + "pymochow.model": pymochow_model, + "pymochow.model.database": pymochow_model_database, + "pymochow.model.enum": pymochow_model_enum, + "pymochow.model.schema": pymochow_model_schema, + "pymochow.model.table": pymochow_model_table, + } + return modules + + +@pytest.fixture +def baidu_module(monkeypatch): + for name, module in _build_fake_pymochow_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + import core.rag.datasource.vdb.baidu.baidu_vector as module + + return importlib.reload(module) + + +def test_baidu_config_validation(baidu_module): + values = { + "endpoint": "https://example.com", + "account": "account", + "api_key": "key", + "database": "database", + } + config = baidu_module.BaiduConfig.model_validate(values) + assert config.endpoint == "https://example.com" + + for key, error_message in [ + ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"), + ("account", "BAIDU_VECTOR_DB_ACCOUNT"), + ("api_key", "BAIDU_VECTOR_DB_API_KEY"), + ("database", "BAIDU_VECTOR_DB_DATABASE"), + ]: + invalid = dict(values) + invalid[key] = "" + with pytest.raises(ValueError, match=error_message): + baidu_module.BaiduConfig.model_validate(invalid) + + +def test_get_search_result_handles_metadata_and_threshold(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace( + rows=[ + {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9}, + {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4}, + {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95}, + ] + ) + + docs = vector._get_search_res(response, score_threshold=0.8) + + assert len(docs) == 2 + assert docs[0].page_content == "doc1" + assert docs[0].metadata["score"] == 0.9 + assert docs[1].page_content == "doc3" + + +def test_delete_by_ids_and_delete_by_metadata_field(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + vector._collection_name = "collection_1" + + vector.delete_by_ids([]) + table.delete.assert_not_called() + + vector.delete_by_ids(["id1", "id2"]) + table.delete.assert_called_once() + + table.delete.reset_mock() + vector.delete_by_metadata_field("source", 'abc"def') + delete_filter = table.delete.call_args.kwargs["filter"] + assert delete_filter == 'metadata["source"] = "abc\\"def"' + + +def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + + vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + vector.delete() + + vector._db.drop_table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector.delete() + + +def test_init_database_uses_existing_or_creates_when_missing(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + + vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")] + vector._client.database.return_value = "existing_db" + assert vector._init_database() == "existing_db" + + vector._client.list_databases.return_value = [] + vector._client.database.return_value = "created_db" + vector._client.create_database.side_effect = None + assert vector._init_database() == "created_db" + + vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST) + assert vector._init_database() == "created_db" + + +def test_table_existed_checks_table_access(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._db.table.return_value = MagicMock() + + assert vector._table_existed() is True + + vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + assert vector._table_existed() is False + + vector._db.table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector._table_existed() + + +def test_search_methods_delegate_to_database_table(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})]) + + table = MagicMock() + vector._db.table.return_value = table + table.search.return_value = "vector_result" + table.bm25_search.return_value = "bm25_result" + + result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + + assert result1 == vector._get_search_res.return_value + assert result2 == vector._get_search_res.return_value + assert vector._get_search_res.call_count == 2 + + +def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection" + assert dataset.index_struct is not None + + +def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch): + init_client = MagicMock(return_value="client") + init_database = MagicMock(return_value="database") + monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client) + monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database) + + config = baidu_module.BaiduConfig( + endpoint="https://example.com", + account="account", + api_key="key", + database="db", + ) + vector = baidu_module.BaiduVector(collection_name="my_collection", config=config) + + assert vector.get_type() == baidu_module.VectorType.BAIDU + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection" + assert vector._client == "client" + assert vector._db == "database" + + vector._create_table = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="p1", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_table.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_rows(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]]) + + assert table.upsert.call_count == 1 + inserted_rows = table.upsert.call_args.kwargs["rows"] + assert len(inserted_rows) == 2 + + +def test_add_texts_batches_more_than_batch_size(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"}) + for idx in range(1001) + ] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + vector.add_texts(docs, embeddings) + + assert table.upsert.call_count == 2 + assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000 + assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1 + + +def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + table.query.return_value = SimpleNamespace(code=0) + assert vector.text_exists("id-1") is True + + table.query.return_value = SimpleNamespace(code=1) + assert vector.text_exists("id-1") is False + + table.query.return_value = None + assert vector.text_exists("id-1") is False + + +def test_get_search_result_handles_invalid_metadata_json(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]) + + docs = vector._get_search_res(response, score_threshold=0.1) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.7 + assert "document_id" not in docs[0].metadata + + +def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch): + credentials = MagicMock(return_value="credentials") + configuration = MagicMock(return_value="configuration") + client_cls = MagicMock(return_value="client") + monkeypatch.setattr(baidu_module, "BceCredentials", credentials) + monkeypatch.setattr(baidu_module, "Configuration", configuration) + monkeypatch.setattr(baidu_module, "MochowClient", client_cls) + + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + + client = vector._init_client(config) + + assert client == "client" + credentials.assert_called_once_with("account", "key") + configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + client_cls.assert_called_once_with("configuration") + + +def test_init_database_raises_for_unknown_create_database_error(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + vector._client.list_databases.return_value = [] + vector._client.create_database.side_effect = baidu_module.ServerError(9999) + + with pytest.raises(baidu_module.ServerError): + vector._init_database() + + +def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + table = MagicMock() + table.state = baidu_module.TableState.NORMAL + vector._db.describe_table.return_value = table + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock()) + + # Cached table skips all work. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Existing table also skips creation. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + vector._table_existed.return_value = True + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Create table when cache is empty and table does not exist. + vector._table_existed.return_value = False + vector._create_table(3) + vector._db.create_table.assert_called_once() + baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600) + table.rebuild_index.assert_called_once_with(vector.vector_index) + vector._wait_for_index_ready.assert_called_once_with(table, 3600) + + +def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + vector._client_config = SimpleNamespace( + index_type="INVALID", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_table(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "INVALID" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_table(3) + + +def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + vector._db.describe_table.return_value = SimpleNamespace(state="CREATING") + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301])) + + with pytest.raises(TimeoutError, match="Table creation timeout"): + vector._create_table(3) + + +def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py new file mode 100644 index 0000000000..44427b7d87 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py @@ -0,0 +1,199 @@ +import importlib +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_chroma_modules(): + chroma = types.ModuleType("chromadb") + chroma.DEFAULT_TENANT = "default_tenant" + chroma.DEFAULT_DATABASE = "default_database" + + class Settings: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class QueryResult(UserDict): + pass + + class _Collection: + def __init__(self): + self.upsert = MagicMock() + self.delete = MagicMock() + self.query = MagicMock() + self.get = MagicMock(return_value={}) + + class _Client: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.collection = _Collection() + self.get_or_create_collection = MagicMock(return_value=self.collection) + self.delete_collection = MagicMock() + + chroma.Settings = Settings + chroma.QueryResult = QueryResult + chroma.HttpClient = _Client + return chroma + + +@pytest.fixture +def chroma_module(monkeypatch): + fake_chroma = _build_fake_chroma_modules() + monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) + import core.rag.datasource.vdb.chroma.chroma_vector as module + + return importlib.reload(module) + + +def test_chroma_config_to_params_builds_expected_payload(chroma_module): + config = chroma_module.ChromaConfig( + host="localhost", + port=8000, + tenant="tenant-1", + database="db-1", + auth_provider="provider", + auth_credentials="credentials", + ) + + params = config.to_chroma_params() + + assert params["host"] == "localhost" + assert params["port"] == 8000 + assert params["tenant"] == "tenant-1" + assert params["database"] == "db-1" + assert params["ssl"] is False + assert params["settings"].chroma_client_auth_provider == "provider" + assert params["settings"].chroma_client_auth_credentials == "credentials" + + +def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock()) + + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create_collection("collection_1") + + vector._client.get_or_create_collection.assert_called_once_with("collection_1") + chroma_module.redis_client.set.assert_called_once() + + +def test_create_with_empty_texts_is_noop(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create([], []) + vector._client.get_or_create_collection.assert_not_called() + + +def test_create_with_texts_creates_collection_and_upserts(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._client.get_or_create_collection.assert_called() + vector._client.collection.upsert.assert_called_once() + + +def test_delete_methods_and_text_exists(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + + vector.delete_by_ids([]) + vector._client.collection.delete.assert_not_called() + + vector.delete_by_ids(["id-1"]) + vector._client.collection.delete.assert_called_with(ids=["id-1"]) + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}}) + + vector._client.collection.get.return_value = {"ids": ["id-1"]} + assert vector.text_exists("id-1") is True + vector._client.collection.get.return_value = {} + assert vector.text_exists("id-2") is False + + vector.delete() + vector._client.delete_collection.assert_called_once_with("collection_1") + + +def test_search_by_vector_handles_empty_results(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = {"ids": [], "documents": [], "metadatas": [], "distances": []} + + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + +def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = { + "ids": [["id-1", "id-2"]], + "documents": [["doc high", "doc low"]], + "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]], + "distances": [[0.1, 0.8]], + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc high" + assert docs[0].metadata["score"] == 0.9 + + +def test_search_by_full_text_returns_empty_list(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + assert vector.search_by_full_text("query") == [] + + +def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch): + factory = chroma_module.ChromaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None) + + with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py new file mode 100644 index 0000000000..0ce5c04dd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py @@ -0,0 +1,927 @@ +import importlib +import queue +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickzetta_module(): + clickzetta = types.ModuleType("clickzetta") + + class _FakeCursor: + def __init__(self): + self.execute = MagicMock() + self.executemany = MagicMock() + self.fetchall = MagicMock(return_value=[]) + self.fetchone = MagicMock(return_value=(0,)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _FakeConnection: + def __init__(self): + self.cursor_obj = _FakeCursor() + + def cursor(self): + return self.cursor_obj + + def close(self): + return None + + def connect(**_kwargs): + return _FakeConnection() + + clickzetta.connect = connect + return clickzetta + + +@pytest.fixture +def clickzetta_module(monkeypatch): + monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) + import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.ClickzettaConfig( + username="username", + password="password", + instance="instance", + service="service", + workspace="workspace", + vcluster="cluster", + schema_name="dify", + ) + + +@pytest.mark.parametrize( + ("field", "error_message"), + [ + ("username", "CLICKZETTA_USERNAME"), + ("password", "CLICKZETTA_PASSWORD"), + ("instance", "CLICKZETTA_INSTANCE"), + ("service", "CLICKZETTA_SERVICE"), + ("workspace", "CLICKZETTA_WORKSPACE"), + ("vcluster", "CLICKZETTA_VCLUSTER"), + ("schema_name", "CLICKZETTA_SCHEMA"), + ], +) +def test_clickzetta_config_validation(clickzetta_module, field, error_message): + values = _config(clickzetta_module).model_dump() + values[field] = "" + with pytest.raises(ValueError, match=error_message): + clickzetta_module.ClickzettaConfig.model_validate(values) + + +def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1") + assert parsed["doc_id"] == "row-1" + assert parsed["document_id"] == "doc-1" + + parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2") + assert parsed_double["doc_id"] == "row-2" + assert parsed_double["document_id"] == "doc-2" + + parsed_fallback = vector._parse_metadata("not-json", "row-3") + assert parsed_fallback["doc_id"] == "row-3" + assert parsed_fallback["document_id"] == "row-3" + + +def test_safe_doc_id_and_vector_format_helpers(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + assert vector._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3" + assert vector._safe_doc_id("abc-123_DEF") == "abc-123_DEF" + assert vector._safe_doc_id("ab c;\n") == "abc" + assert len(vector._safe_doc_id("a" * 300)) == 255 + + +def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_not_found(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_not_found + assert vector._table_exists() is False + + @contextmanager + def _ctx_other_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("permission denied") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_other_error + assert vector._table_exists() is False + + +def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.text_exists("doc-1") is False + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchone.return_value = (1,) + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + assert vector.text_exists("doc-1") is True + + +def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + + vector.delete_by_ids([]) + vector._execute_write.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + vector.delete_by_ids(["doc-1"]) + vector._execute_write.assert_not_called() + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_not_called() + + +def test_search_short_circuit_behaviors(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + vector._config.enable_inverted_index = False + assert vector.search_by_full_text("query", top_k=2) == [] + + +def test_search_by_like_returns_documents_with_default_score(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content" + assert docs[0].metadata["score"] == 0.5 + + +def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch): + factory = clickzetta_module.ClickzettaVectorFactory() + dataset = SimpleNamespace(id="dataset-1") + + monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance") + + with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "collection" + + +def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch): + clickzetta_module.ClickzettaConnectionPool._instance = None + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + + pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance() + pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance() + key = pool_1._get_config_key(_config(clickzetta_module)) + + assert pool_1 is pool_2 + assert "username:instance:service:workspace:cluster:dify" in key + + +def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + connection = MagicMock() + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr( + clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection]) + ) + pool._configure_connection = MagicMock() + + created = pool._create_connection(config) + + assert created is connection + assert clickzetta_module.clickzetta.connect.call_count == 2 + pool._configure_connection.assert_called_once_with(connection) + + +def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + pool._create_connection(config) + + +def test_connection_pool_configure_and_validate_connection(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection = MagicMock() + connection.cursor.return_value = cursor + + pool._configure_connection(connection) + assert cursor.execute.call_count >= 2 + assert pool._is_connection_valid(connection) is True + + bad_connection = MagicMock() + bad_connection.cursor.side_effect = RuntimeError("bad connection") + assert pool._is_connection_valid(bad_connection) is False + monkeypatch.undo() + + +def test_connection_pool_configure_connection_swallows_errors(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + connection = MagicMock() + connection.cursor.side_effect = RuntimeError("cannot configure") + + pool._configure_connection(connection) + monkeypatch.undo() + + +def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + key = pool._get_config_key(config) + + created_connection = MagicMock() + pool._create_connection = MagicMock(return_value=created_connection) + first = pool.get_connection(config) + assert first is created_connection + + reusable_connection = MagicMock() + pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())] + pool._is_connection_valid = MagicMock(return_value=True) + reused = pool.get_connection(config) + assert reused is reusable_connection + + expired_connection = MagicMock() + pool._pools[key] = [(expired_connection, 0.0)] + pool._is_connection_valid = MagicMock(return_value=False) + monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0)) + pool.get_connection(config) + expired_connection.close.assert_called_once() + + random_connection = MagicMock() + pool._is_connection_valid = MagicMock(return_value=True) + pool.return_connection(config, random_connection) + assert len(pool._pools[key]) == 1 + + pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)] + pool._connection_timeout = 10 + pool._cleanup_expired_connections() + assert len(pool._pools[key]) == 1 + + unknown_pool = MagicMock() + pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool) + unknown_pool.close.assert_called_once() + + pool.shutdown() + assert pool._shutdown is True + + +def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True)) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + self.started = False + + def start(self): + self.started = True + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + + assert pool._cleanup_thread.started is True + pool._cleanup_expired_connections.assert_called_once() + + +def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch): + pool = MagicMock() + pool.get_connection.return_value = "conn" + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool)) + monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock()) + + vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module)) + assert vector._table_name == "my_collection" + + assert vector._get_connection() == "conn" + vector._return_connection("conn") + pool.return_connection.assert_called_with(vector._config, "conn") + + with vector.get_connection_context() as conn: + assert conn == "conn" + assert pool.return_connection.call_count >= 2 + + assert vector.get_type() == "clickzetta" + assert vector._ensure_connection() == "conn" + + +def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch): + class _Thread: + def __init__(self, target, daemon): + self.target = target + self.daemon = daemon + self.started = 0 + + def start(self): + self.started += 1 + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_thread = None + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._init_write_queue() + clickzetta_module.ClickzettaVector._init_write_queue() + assert clickzetta_module.ClickzettaVector._write_thread.started == 1 + + result_queue_ok = queue.Queue() + result_queue_fail = queue.Queue() + clickzetta_module.ClickzettaVector._write_queue = queue.Queue() + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok)) + clickzetta_module.ClickzettaVector._write_queue.put( + (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail) + ) + clickzetta_module.ClickzettaVector._write_queue.put(None) + clickzetta_module.ClickzettaVector._write_worker() + + assert result_queue_ok.get() == (True, 2) + failed = result_queue_fail.get() + assert failed[0] is False + assert isinstance(failed[1], RuntimeError) + + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + clickzetta_module.ClickzettaVector._write_queue = None + with pytest.raises(RuntimeError, match="Write queue not initialized"): + vector._execute_write(lambda: None) + + class _ImmediateSuccessQueue: + def put(self, task): + func, args, kwargs, result_q = task + result_q.put((True, func(*args, **kwargs))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue() + assert vector._execute_write(lambda x: x * 2, 3) == 6 + + class _ImmediateFailQueue: + def put(self, task): + _, _, _, result_q = task + result_q.put((False, ValueError("write failed"))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue() + with pytest.raises(ValueError, match="write failed"): + vector._execute_write(lambda: None) + + +def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_exists(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_exists + assert vector._table_exists() is True + + vector._execute_write = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="content", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._execute_write.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_table_and_indexes_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._create_vector_index = MagicMock() + vector._create_inverted_index = MagicMock() + + vector._table_exists = MagicMock(return_value=True) + vector._create_table_and_indexes([[0.1, 0.2]]) + vector._create_vector_index.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._create_table_and_indexes([[0.1, 0.2, 0.3]]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_called_once() + + vector._config.enable_inverted_index = False + vector._create_vector_index.reset_mock() + vector._create_inverted_index.reset_mock() + vector._create_table_and_indexes([]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_not_called() + + +def test_create_vector_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show index failed"), None] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("already exists")] + cursor.fetchall.return_value = [] + vector._create_vector_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("unexpected")] + cursor.fetchall.return_value = [] + with pytest.raises(RuntimeError, match="unexpected"): + vector._create_vector_index(cursor) + + +def test_create_inverted_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show failed"), None] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [ + None, + RuntimeError("already has index"), + None, + ] + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("other create failure")] + cursor.fetchall.return_value = [] + vector._create_inverted_index(cursor) + + +def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.batch_size = 2 + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id)) + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2"}), + Document(page_content="doc-3", metadata={"doc_id": "id-3"}), + ] + vectors = [[0.1], [0.2], [0.3]] + + vector.add_texts([], []) + vector._execute_write.assert_not_called() + + added_ids = vector.add_texts(docs, vectors) + assert added_ids == ["id-1", "id-2", "id-3"] + assert vector._execute_write.call_count == 2 + assert vector._execute_write.call_args_list[0].args == ( + vector._insert_batch, + docs[:2], + vectors[:2], + ["id-1", "id-2"], + 0, + 2, + 2, + ) + assert vector._execute_write.call_args_list[1].args == ( + vector._insert_batch, + docs[2:], + vectors[2:], + ["id-3"], + 2, + 2, + 2, + ) + + vector._insert_batch([], [], [], 0, 2, 1) + vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1) + + bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}}) + good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [bad_doc, good_doc], + [[0.1, 0.2], [0.3, 0.4]], + ["id-bad", "id-good"], + 0, + 2, + 1, + ) + + @contextmanager + def _ctx_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.executemany.side_effect = RuntimeError("insert failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_error + with pytest.raises(RuntimeError, match="insert failed"): + vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1) + + monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id") + vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector) + assert vector._safe_doc_id("") == "generated-id" + assert vector._safe_doc_id("!!!") == "generated-id" + + +def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._table_exists = MagicMock(return_value=True) + + vector.delete_by_ids(["id-1", "id-2"]) + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl + + vector._execute_write.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl + + vector._safe_doc_id = MagicMock(side_effect=lambda x: x) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._delete_by_ids_impl(["id-1", "id-2"]) + vector._delete_by_metadata_field_impl("document_id", "doc-1") + + +def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.vector_distance_function = "cosine_distance" + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + cosine_docs = vector.search_by_vector( + [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"} + ) + assert cosine_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._config.vector_distance_function = "l2_distance" + l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2) + + +def test_search_by_full_text_success_and_fallback(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx_success(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'), + ("seg-2", "content-2", "invalid-json"), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_success + docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1}) + assert len(docs) == 2 + assert docs[0].metadata["score"] == 1.0 + assert docs[1].metadata["doc_id"] == "seg-2" + + @contextmanager + def _ctx_failure(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("full text failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_failure + vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})]) + fallback_docs = vector.search_by_full_text("query", top_k=1) + assert fallback_docs == vector._search_by_like.return_value + + +def test_search_by_like_missing_table_and_delete_table(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=False) + assert vector._search_by_like("query", top_k=1) == [] + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector.delete() + + +def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._pools = {} + pool._pool_locks = {} + pool._max_pool_size = 1 + pool._connection_timeout = 10 + pool._lock = clickzetta_module.threading.Lock() + pool._shutdown = False + + config = _config(clickzetta_module) + key = pool._get_config_key(config) + pool._pools[key] = [(MagicMock(), 1.0)] + pool._pool_locks[key] = clickzetta_module.threading.Lock() + pool._is_connection_valid = MagicMock(return_value=False) + + conn = MagicMock() + pool.return_connection(config, conn) + conn.close.assert_called_once() + + pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)] + pool._cleanup_expired_connections() + pool.shutdown() + assert pool._shutdown is True + + +def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + + def _cleanup_then_fail(): + pool._shutdown = True + raise RuntimeError("cleanup failed") + + pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail) + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + + def start(self): + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + pool._cleanup_expired_connections.assert_called_once() + + +def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1") + assert parsed_non_dict["doc_id"] == "row-1" + assert parsed_non_dict["document_id"] == "row-1" + + parsed_none = vector._parse_metadata(None, "row-2") + assert parsed_none["doc_id"] == "row-2" + assert parsed_none["document_id"] == "row-2" + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_worker() + + class _BadQueue: + def get(self, timeout): + clickzetta_module.ClickzettaVector._shutdown = True + raise RuntimeError("queue failed") + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = _BadQueue() + clickzetta_module.ClickzettaVector._write_worker() + + +def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + cursor.execute.side_effect = [ + None, + RuntimeError("already has index with the same type cannot create inverted index"), + None, + ] + + vector._create_inverted_index(cursor) + + vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value)) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor_obj = MagicMock() + cursor_obj.__enter__.return_value = cursor_obj + cursor_obj.__exit__.return_value = None + connection.cursor.return_value = cursor_obj + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [SimpleNamespace(page_content="content", metadata="not-a-dict")], + [[0.1, 0.2]], + ["doc-1"], + 0, + 1, + 1, + ) + + +def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.enable_inverted_index = True + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_full_text("query") == [] + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", "[1,2,3]"), + ("seg-2", "content-2", None), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector.search_by_full_text("query") + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[1].metadata["doc_id"] == "seg-2" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py new file mode 100644 index 0000000000..9fea187615 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py @@ -0,0 +1,364 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_couchbase_modules(): + couchbase = types.ModuleType("couchbase") + couchbase_auth = types.ModuleType("couchbase.auth") + couchbase_cluster = types.ModuleType("couchbase.cluster") + couchbase_management = types.ModuleType("couchbase.management") + couchbase_management_search = types.ModuleType("couchbase.management.search") + couchbase_options = types.ModuleType("couchbase.options") + couchbase_vector = types.ModuleType("couchbase.vector_search") + couchbase_search = types.ModuleType("couchbase.search") + + class PasswordAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + + class ClusterOptions: + def __init__(self, auth): + self.auth = auth + + class SearchOptions: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorQuery: + def __init__(self, field, vector, top_k): + self.field = field + self.vector = vector + self.top_k = top_k + + class VectorSearch: + @staticmethod + def from_vector_query(vector_query): + return {"vector_query": vector_query} + + class QueryStringQuery: + def __init__(self, query): + self.query = query + + class SearchRequest: + @staticmethod + def create(payload): + return {"payload": payload} + + class SearchIndex: + def __init__(self, name, params, source_name): + self.name = name + self.params = params + self.source_name = source_name + + class _QueryResult: + def __init__(self, rows=None): + self._rows = rows or [] + + def execute(self): + return self + + def __iter__(self): + return iter(self._rows) + + class _SearchIter: + def __init__(self, rows=None): + self._rows = rows or [] + + def rows(self): + return self._rows + + class _Collection: + def __init__(self): + self.upsert = MagicMock(return_value=True) + + class _SearchIndexManager: + def __init__(self): + self.upsert_index = MagicMock() + + class _Scope: + def __init__(self): + self._collection = _Collection() + self._search_index_manager = _SearchIndexManager() + self.search = MagicMock(return_value=_SearchIter()) + + def collection(self, _name): + return self._collection + + def search_indexes(self): + return self._search_index_manager + + class _CollectionManager: + def __init__(self): + self.create_collection = MagicMock() + self.drop_collection = MagicMock() + self.get_all_scopes = MagicMock(return_value=[]) + + class _Bucket: + def __init__(self): + self._scope = _Scope() + self._collections = _CollectionManager() + + def scope(self, _scope_name): + return self._scope + + def collections(self): + return self._collections + + class Cluster: + def __init__(self, connection_string, options): + self.connection_string = connection_string + self.options = options + self._bucket = _Bucket() + self.wait_until_ready = MagicMock() + self.query = MagicMock(return_value=_QueryResult()) + + def bucket(self, _name): + return self._bucket + + couchbase_auth.PasswordAuthenticator = PasswordAuthenticator + couchbase_cluster.Cluster = Cluster + couchbase_management_search.SearchIndex = SearchIndex + couchbase_options.ClusterOptions = ClusterOptions + couchbase_options.SearchOptions = SearchOptions + couchbase_vector.VectorQuery = VectorQuery + couchbase_vector.VectorSearch = VectorSearch + couchbase_search.QueryStringQuery = QueryStringQuery + couchbase_search.SearchRequest = SearchRequest + + couchbase.search = couchbase_search + couchbase.management = couchbase_management + + return { + "couchbase": couchbase, + "couchbase.auth": couchbase_auth, + "couchbase.cluster": couchbase_cluster, + "couchbase.management": couchbase_management, + "couchbase.management.search": couchbase_management_search, + "couchbase.options": couchbase_options, + "couchbase.vector_search": couchbase_vector, + "couchbase.search": couchbase_search, + } + + +@pytest.fixture +def couchbase_module(monkeypatch): + for name, module in _build_fake_couchbase_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.couchbase.couchbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.CouchbaseConfig( + connection_string="couchbase://localhost", + user="user", + password="pass", + bucket_name="bucket", + scope_name="scope", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("connection_string", "", "CONNECTION_STRING is required"), + ("user", "", "COUCHBASE_USER is required"), + ("password", "", "COUCHBASE_PASSWORD is required"), + ("bucket_name", "", "COUCHBASE_PASSWORD is required"), + ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"), + ], +) +def test_couchbase_config_validation(couchbase_module, field, value, message): + values = _config(couchbase_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + couchbase_module.CouchbaseConfig.model_validate(values) + + +def test_init_sets_cluster_handles(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + assert vector._bucket_name == "bucket" + assert vector._scope_name == "scope" + vector._cluster.wait_until_ready.assert_called_once() + + +def test_create_and_create_collection_branches(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector) + vector._collection_name = "collection_1" + vector._client_config = _config(couchbase_module) + vector._scope_name = "scope" + vector._bucket_name = "bucket" + vector._bucket = MagicMock() + vector._scope = MagicMock() + vector._collection_exists = MagicMock(return_value=False) + vector.add_texts = MagicMock() + + monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c") + vector._create_collection = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock()) + + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(vector_length=2, uuid="uuid-1") + vector._bucket.collections().create_collection.assert_not_called() + + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._collection_exists = MagicMock(return_value=True) + vector._create_collection(vector_length=2, uuid="uuid-2") + vector._bucket.collections().create_collection.assert_not_called() + + vector._collection_exists = MagicMock(return_value=False) + vector._create_collection(vector_length=3, uuid="uuid-3") + + vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1") + vector._scope.search_indexes().upsert_index.assert_called_once() + search_index = vector._scope.search_indexes().upsert_index.call_args.args[0] + assert search_index.name == "collection_1_search" + assert ( + search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"] + == 3 + ) + couchbase_module.redis_client.set.assert_called_once() + + +def test_collection_exists_get_type_and_add_texts(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is True + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is False + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert vector._scope.collection("collection_1").upsert.call_count == 2 + assert vector.get_type() == couchbase_module.VectorType.COUCHBASE + + +def test_query_delete_helpers(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}])) + assert vector.text_exists("id-1") is True + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([])) + assert vector.text_exists("id-2") is False + + query_result = MagicMock() + query_result.execute.return_value = None + vector._cluster.query.return_value = query_result + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_document_id("id-1") + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._cluster.query.call_count >= 3 + + vector._cluster.query.side_effect = RuntimeError("delete failed") + vector.delete_by_ids(["id-3"]) + + +def test_search_methods_and_format_metadata(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9) + row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["document_id"] == "d-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector._scope.search.side_effect = RuntimeError("search error") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_vector([0.1], top_k=1) + + vector._scope.search.side_effect = None + row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3]) + docs = vector.search_by_full_text("hello", top_k=1) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == "x" + + vector._scope.search.side_effect = RuntimeError("full text failed") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_full_text("hello", top_k=1) + + assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2} + + +def test_delete_collection_and_factory(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + scopes = [ + SimpleNamespace(collections=[SimpleNamespace(name="other")]), + SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]), + ] + vector._bucket.collections().get_all_scopes.return_value = scopes + + vector.delete() + vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1") + + factory = couchbase_module.CouchbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + couchbase_module, + "current_app", + SimpleNamespace( + config={ + "COUCHBASE_CONNECTION_STRING": "couchbase://localhost", + "COUCHBASE_USER": "user", + "COUCHBASE_PASSWORD": "pass", + "COUCHBASE_BUCKET_NAME": "bucket", + "COUCHBASE_SCOPE_NAME": "scope", + } + ), + ) + + with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py new file mode 100644 index 0000000000..edd29a4649 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py @@ -0,0 +1,121 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0"}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_ja_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + + importlib.reload(base_module) + return importlib.reload(ja_module) + + +def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_not_called() + + +def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2, 0.3]], [{}]) + + vector._client.indices.create.assert_called_once() + kwargs = vector._client.indices.create.call_args.kwargs + assert kwargs["index"] == "test" + assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3 + elasticsearch_ja_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_ja_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_called_once() + + +def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch): + factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + elasticsearch_ja_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_HOST": "localhost", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + + with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py new file mode 100644 index 0000000000..9ecf0caa24 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py @@ -0,0 +1,405 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}}) + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), + delete=MagicMock(), + exists=MagicMock(return_value=False), + create=MagicMock(), + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + + return importlib.reload(module) + + +def _regular_config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "username": "elastic", + "password": "secret", + "verify_certs": False, + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +def _cloud_config(module, **overrides): + values = { + "use_cloud": True, + "cloud_url": "https://cloud.example:9243", + "api_key": "api-key", + "verify_certs": True, + "ca_certs": "/tmp/ca.pem", + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("values", "message"), + [ + ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"), + ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"), + ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"), + ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"), + ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"), + ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"), + ], +) +def test_elasticsearch_config_validation(elasticsearch_module, values, message): + with pytest.raises(ValidationError, match=message): + elasticsearch_module.ElasticSearchConfig.model_validate(values) + + +def test_init_client_cloud_configuration(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + result = vector._init_client(_cloud_config(elasticsearch_module)) + + assert result is client + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://cloud.example:9243"] + assert kwargs["api_key"] == "api-key" + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + +def test_init_client_regular_https_and_http_fallback(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client( + _regular_config( + elasticsearch_module, + host="https://es.example", + port=9443, + verify_certs=True, + ca_certs="/tmp/ca.pem", + ) + ) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://es.example:9443"] + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200)) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["http://es.internal:9200"] + assert "verify_certs" not in kwargs + + +def test_init_client_connection_failures(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + + client = MagicMock() + client.ping.return_value = False + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client): + with pytest.raises(ConnectionError, match="Failed to connect"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object( + elasticsearch_module, + "Elasticsearch", + side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"), + ): + with pytest.raises(ConnectionError, match="Vector database connection error"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")): + with pytest.raises(ConnectionError, match="initialization failed"): + vector._init_client(_regular_config(elasticsearch_module)) + + +def test_init_get_version_and_check_version(elasticsearch_module): + with ( + patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client, + patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version, + patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version, + ): + vector = elasticsearch_module.ElasticSearchVector( + "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"] + ) + + init_client.assert_called_once() + get_version.assert_called_once() + check_version.assert_called_once() + assert vector._attributes == ["doc_id"] + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._client = MagicMock() + vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}} + assert vector._get_version() == "8.13.2" + + vector._version = "7.17.0" + with pytest.raises(ValueError, match="greater than 8.0.0"): + vector._check_version() + + vector._version = "8.0.0" + vector._check_version() + + +def test_crud_methods_and_get_type(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock()) + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection_1") + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._client.delete.call_count == 2 + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "d1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "d2") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1") + assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH + + +def test_search_by_vector_and_full_text(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.8, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-a", + elasticsearch_module.Field.VECTOR: [0.1], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-b", + elasticsearch_module.Field.VECTOR: [0.2], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + knn = vector._client.search.call_args.kwargs["knn"] + assert knn["k"] == 2 + assert knn["num_candidates"] == 3 + assert "filter" in knn + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "text-hit", + elasticsearch_module.Field.VECTOR: [0.3], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + query = vector._client.search.call_args.kwargs["query"] + assert "bool" in query + + +def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + mappings = vector._client.indices.create.call_args.kwargs["mappings"] + assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2 + elasticsearch_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + elasticsearch_module.redis_client.set.assert_called_once() + + +def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch): + factory = elasticsearch_module.ElasticSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": False, + "ELASTICSEARCH_HOST": "es-host", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + "ELASTICSEARCH_VERIFY_CERTS": False, + } + ), + ) + + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + assert result_1 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION" + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic", + "ELASTICSEARCH_API_KEY": "api-key", + "ELASTICSEARCH_VERIFY_CERTS": True, + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + assert result_2 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is True + assert cfg.cloud_url == "https://cloud.elastic" + assert dataset_without_index.index_struct is not None + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": None, + "ELASTICSEARCH_HOST": "fallback-host", + "ELASTICSEARCH_PORT": 9201, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert cfg.host == "fallback-host" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py new file mode 100644 index 0000000000..5d9e744ded --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py @@ -0,0 +1,371 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_hologres_modules(): + holo_module = types.ModuleType("holo_search_sdk") + holo_types_module = types.ModuleType("holo_search_sdk.types") + + holo_types_module.BaseQuantizationType = str + holo_types_module.DistanceType = str + holo_types_module.TokenizerType = str + + def _connect(**kwargs): + client = MagicMock() + client.kwargs = kwargs + client.connect = MagicMock() + client.check_table_exist = MagicMock(return_value=False) + client.open_table = MagicMock(return_value=MagicMock()) + client.execute = MagicMock(return_value=[]) + client.drop_table = MagicMock() + return client + + holo_module.connect = MagicMock(side_effect=_connect) + + return { + "holo_search_sdk": holo_module, + "holo_search_sdk.types": holo_types_module, + } + + +@pytest.fixture +def hologres_module(monkeypatch): + for name, module in _build_fake_hologres_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.hologres.hologres_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.HologresVectorConfig( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema_name="public", + tokenizer="jieba", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config HOLOGRES_HOST is required"), + ("database", "", "config HOLOGRES_DATABASE is required"), + ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"), + ], +) +def test_hologres_config_validation(hologres_module, field, value, message): + values = _valid_config(hologres_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + hologres_module.HologresVectorConfig.model_validate(values) + + +def test_init_client_and_get_type(hologres_module): + vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module)) + + hologres_module.holo.connect.assert_called_once_with( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema="public", + ) + vector._client.connect.assert_called_once() + assert vector.table_name == "embedding_collection_one" + assert vector.get_type() == hologres_module.VectorType.HOLOGRES + + +def test_create_delegates_collection_creation_and_upsert(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result is None + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_returns_empty_for_empty_documents(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + assert vector.add_texts([], []) == [] + vector._client.open_table.assert_not_called() + + +def test_add_texts_batches_and_serializes_metadata(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + table = vector._client.open_table.return_value + documents = [ + Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"}) + for i in range(100) + ] + documents.append(SimpleNamespace(page_content="doc-100", metadata=None)) + embeddings = [[float(i)] for i in range(len(documents))] + + ids = vector.add_texts(documents, embeddings) + + assert ids[:2] == ["id-0", "id-1"] + assert ids[-1] == "" + assert len(ids) == 101 + assert vector._client.open_table.call_count == 2 + assert table.upsert_multi.call_count == 2 + first_call = table.upsert_multi.call_args_list[0].kwargs + second_call = table.upsert_multi.call_args_list[1].kwargs + assert first_call["index_column"] == "id" + assert first_call["column_names"] == ["id", "text", "meta", "embedding"] + assert first_call["update_columns"] == ["text", "meta", "embedding"] + assert len(first_call["values"]) == 100 + assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"} + assert second_call["values"][0][0] == "" + assert second_call["values"][0][2] == "{}" + + +def test_text_exists_handles_missing_and_present_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + vector._client.execute.return_value = [(1,)] + + assert vector.text_exists("seg-1") is False + assert vector.text_exists("seg-1") is True + vector._client.execute.assert_called_once() + + +def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +def test_delete_by_ids_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + vector.delete_by_ids([]) + vector._client.check_table_exist.assert_not_called() + + vector._client.check_table_exist.return_value = False + vector.delete_by_ids(["id-1"]) + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_ids(["id-1", "id-2"]) + vector._client.execute.assert_called_once() + + +def test_delete_by_metadata_field_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_called_once() + + +def test_search_by_vector_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_vector([0.1, 0.2]) == [] + + +def test_search_by_vector_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + query = MagicMock() + table.search_vector.return_value = query + query.select.return_value = query + query.limit.return_value = query + query.where.return_value = query + query.fetchall.return_value = [ + (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'), + (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["doc-1"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + table.search_vector.assert_called_once() + query.where.assert_called_once() + + +def test_search_by_full_text_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_full_text("query") == [] + + +def test_search_by_full_text_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + search_query = MagicMock() + table.search_text.return_value = search_query + search_query.limit.return_value = search_query + search_query.where.return_value = search_query + search_query.fetchall.return_value = [ + ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95), + ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"]) + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.95) + assert docs[1].metadata["score"] == pytest.approx(0.7) + table.search_text.assert_called_once() + search_query.where.assert_called_once() + + +def test_delete_handles_existing_and_missing_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + + vector.delete() + vector._client.drop_table.assert_not_called() + + vector.delete() + vector._client.drop_table.assert_called_once_with(vector.table_name) + + +def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection(3) + + vector._client.check_table_exist.assert_not_called() + hologres_module.redis_client.set.assert_not_called() + + +def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, False, True] + table = vector._client.open_table.return_value + + vector._create_collection(3) + + vector._client.execute.assert_called_once() + table.set_vector_index.assert_called_once_with( + column="embedding", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + use_reorder=True, + ) + table.create_text_index.assert_called_once_with( + index_name="ft_idx_collection_one", + column="text", + tokenizer="jieba", + ) + hologres_module.redis_client.set.assert_called_once() + + +def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False] + [False] * 15 + + with pytest.raises(RuntimeError, match="was not ready after 30s"): + vector._create_collection(3) + + hologres_module.redis_client.set.assert_not_called() + + +def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch): + factory = hologres_module.HologresVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400) + + with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection" + generated_config = vector_cls.call_args_list[1].kwargs["config"] + assert generated_config.host == "127.0.0.1" + assert generated_config.database == "dify" + assert generated_config.access_key_id == "ak" + assert json.loads(dataset_without_index.index_struct) == { + "type": hologres_module.VectorType.HOLOGRES, + "vector_store": {"class_prefix": "generated_collection"}, + } diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py new file mode 100644 index 0000000000..9d23dfcf63 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py @@ -0,0 +1,243 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def huawei_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass") + + +def test_create_ssl_context(huawei_module): + ctx = huawei_module.create_ssl_context() + assert ctx.check_hostname is False + assert ctx.verify_mode == huawei_module.ssl.CERT_NONE + + +def test_huawei_config_validation_and_params(huawei_module): + with pytest.raises(ValidationError, match="HOSTS is required"): + huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""}) + + config = _config(huawei_module) + params = config.to_elasticsearch_params() + assert params["hosts"] == ["http://localhost:9200"] + assert params["basic_auth"] == ("user", "pass") + + config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None) + params = config.to_elasticsearch_params() + assert "basic_auth" not in params + + +def test_init_get_type_and_add_texts(huawei_module): + vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module)) + + assert vector._collection_name == "collection" + assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_crud_methods(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once_with(index="collection", id="id-1") + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection") + + +def test_search_by_vector_and_full_text(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + }, + { + "_score": 0.1, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-b", + huawei_module.Field.VECTOR: [0.2], + huawei_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + query_body = vector._client.search.call_args.kwargs["body"] + assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2 + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + huawei_module.Field.CONTENT_KEY: "text-hit", + huawei_module.Field.VECTOR: [0.3], + huawei_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + + +def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch): + class FakeDocument: + def __init__(self, page_content, vector, metadata): + self.page_content = page_content + self.vector = vector + self.metadata = None + + monkeypatch.setattr(huawei_module, "Document", FakeDocument) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + } + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5) + + assert docs == [] + + +def test_create_and_create_collection_paths(huawei_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock()) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + + kwargs = vector._client.indices.create.call_args.kwargs + mappings = kwargs["mappings"] + assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2 + assert kwargs["settings"] == {"index.vector": True} + huawei_module.redis_client.set.assert_called_once() + + +def test_huawei_factory_branches(huawei_module, monkeypatch): + factory = huawei_module.HuaweiCloudVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass") + + with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py new file mode 100644 index 0000000000..63338ca809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py @@ -0,0 +1,412 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_iris_module(): + iris = types.ModuleType("iris") + + def connect(**_kwargs): + conn = MagicMock() + conn.cursor.return_value = MagicMock() + return conn + + iris.connect = MagicMock(side_effect=connect) + return iris + + +@pytest.fixture +def iris_module(monkeypatch): + monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) + + import core.rag.datasource.vdb.iris.iris_vector as module + + reloaded = importlib.reload(module) + reloaded._pool_instance = None + return reloaded + + +def _config(module, **overrides): + values = { + "IRIS_HOST": "localhost", + "IRIS_SUPER_SERVER_PORT": 1972, + "IRIS_USER": "user", + "IRIS_PASSWORD": "pass", + "IRIS_DATABASE": "db", + "IRIS_SCHEMA": "schema", + "IRIS_CONNECTION_URL": "url", + "IRIS_MIN_CONNECTION": 1, + "IRIS_MAX_CONNECTION": 2, + "IRIS_TEXT_INDEX": True, + "IRIS_TEXT_INDEX_LANGUAGE": "en", + } + values.update(overrides) + return module.IrisVectorConfig.model_validate(values) + + +def test_get_iris_pool_singleton(iris_module): + iris_module._pool_instance = None + cfg = _config(iris_module) + + with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls: + pool_1 = iris_module.get_iris_pool(cfg) + pool_2 = iris_module.get_iris_pool(cfg) + + assert pool_1 == "pool" + assert pool_2 == "pool" + pool_cls.assert_called_once_with(cfg) + + +@pytest.fixture +def pool_with_min_max(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn: + pool = iris_module.IrisConnectionPool(cfg) + yield pool, create_conn + + +def test_pool_initialization_respects_min_max(pool_with_min_max): + pool, create_conn = pool_with_min_max + assert len(pool._pool) == 2 + assert create_conn.call_count == 2 + + +@pytest.fixture +def pool_for_get_connection(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_get_connection_returns_existing_and_increments(pool_for_get_connection): + pool = pool_for_get_connection + conn = MagicMock() + pool._pool = [conn] + pool._in_use = 0 + assert pool.get_connection() is conn + assert pool._in_use == 1 + + +def test_get_connection_creates_new_when_empty(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = 0 + pool._create_connection = MagicMock(return_value="new-conn") + assert pool.get_connection() == "new-conn" + + +def test_get_connection_raises_when_exhausted(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = pool._max_size + with pytest.raises(RuntimeError, match="exhausted"): + pool.get_connection() + + +@pytest.fixture +def pool_for_return_connection(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_return_connection_adds_healthy(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool.return_connection(conn) + assert pool._pool[-1] is conn + assert pool._in_use == 0 + + +def test_return_connection_replaces_bad(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + bad_conn = MagicMock() + bad_cursor = MagicMock() + bad_cursor.execute.side_effect = OSError("bad") + bad_conn.cursor.return_value = bad_cursor + replacement = MagicMock() + pool._create_connection = MagicMock(return_value=replacement) + pool.return_connection(bad_conn) + bad_conn.close.assert_called_once() + assert pool._pool[-1] is replacement + assert pool._in_use == 0 + + +def test_return_connection_ignores_none(pool_for_return_connection): + pool = pool_for_return_connection + before = len(pool._pool) + pool.return_connection(None) + assert len(pool._pool) == before + + +@pytest.fixture +def pool_for_schema_and_close(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool._pool = [conn] + return pool, conn, cursor + + +def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = {"cached_schema"} + pool.ensure_schema_exists("cached_schema") + cursor.execute.assert_not_called() + + +def test_ensure_schema_exists_creates_new(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (0,) + pool.ensure_schema_exists("new_schema") + assert "new_schema" in pool._schemas_initialized + assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list) + conn.commit.assert_called_once() + + +def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (1,) + pool.ensure_schema_exists("existing_schema") + conn.commit.assert_not_called() + + +def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.execute.side_effect = RuntimeError("schema failure") + with pytest.raises(RuntimeError, match="schema failure"): + pool.ensure_schema_exists("broken_schema") + conn.rollback.assert_called() + + +def test_close_all_closes_and_resets(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + conn_2 = MagicMock() + conn_2.close.side_effect = OSError("close fail") + pool._pool = [conn, conn_2] + pool._schemas_initialized = {"x"} + pool.close_all() + assert pool._pool == [] + assert pool._in_use == 0 + assert pool._schemas_initialized == set() + + +def test_iris_vector_init_get_cursor_and_create(iris_module): + pool = MagicMock() + pool.get_connection.return_value = MagicMock() + + with patch.object(iris_module, "get_iris_pool", return_value=pool): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + assert vector.table_name == "EMBEDDING_COLLECTION" + assert vector.schema == "schema" + assert vector.get_type() == iris_module.VectorType.IRIS + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + + with vector._get_cursor() as got_cursor: + assert got_cursor is cursor + conn.commit.assert_called_once() + vector.pool.return_connection.assert_called_with(conn) + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + with pytest.raises(RuntimeError, match="boom"): + with vector._get_cursor(): + raise RuntimeError("boom") + conn.rollback.assert_called_once() + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["id-1"]) + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"] + vector._create_collection.assert_called_once_with(2) + + +def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id") + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "generated-id"] + assert cursor.execute.call_count == 2 + + cursor.fetchone.return_value = (1,) + assert vector.text_exists("id-1") is True + cursor.fetchone.return_value = None + assert vector.text_exists("id-2") is False + + vector._get_cursor = MagicMock(side_effect=RuntimeError("db down")) + assert vector.text_exists("id-3") is False + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + before = cursor.execute.call_count + vector.delete_by_ids(["id-1", "id-2"]) + assert cursor.execute.call_count == before + 1 + + vector.delete_by_metadata_field("document_id", "doc-1") + assert "meta LIKE" in cursor.execute.call_args.args[0] + + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.9), + ("id-2", "text-2", '{"document_id":"d-2"}', 0.2), + ("id-x",), + ] + docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + +def test_iris_vector_full_text_search_paths(iris_module, monkeypatch): + cfg = _config(iris_module, IRIS_TEXT_INDEX=True) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", cfg) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + cursor.execute.side_effect = None + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.7), + ("id-2", "text-2", "{}", None), + ] + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.7) + assert docs[1].metadata["score"] == pytest.approx(0.0) + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("rank failed"), None] + cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)] + docs = vector.search_by_full_text("query", top_k=1) + assert len(docs) == 1 + assert cursor.execute.call_count == 2 + + cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector_like = iris_module.IrisVector("collection", cfg_like) + vector_like._get_cursor = _cursor_ctx + + fake_libs = types.ModuleType("libs") + fake_helper = types.ModuleType("libs.helper") + fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%") + monkeypatch.setitem(sys.modules, "libs", fake_libs) + monkeypatch.setitem(sys.modules, "libs.helper", fake_helper) + + cursor.reset_mock() + cursor.execute.side_effect = None + cursor.fetchall.return_value = [] + assert vector_like.search_by_full_text("100%", top_k=1) == [] + + +def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete() + assert "DROP TABLE" in cursor.execute.call_args.args[0] + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(iris_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_called_once() + + cursor.reset_mock() + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None)) + vector.pool.ensure_schema_exists = MagicMock() + vector._create_collection(3) + assert cursor.execute.call_count == 3 + iris_module.redis_client.set.assert_called_once() + + cursor.reset_mock() + vector.config.IRIS_TEXT_INDEX = False + vector._create_collection(3) + assert cursor.execute.call_count == 2 + + factory = iris_module.IrisVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972) + monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user") + monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass") + monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema") + monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url") + monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1) + monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en") + + with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py new file mode 100644 index 0000000000..34357d5907 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py @@ -0,0 +1,394 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearch_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = SimpleNamespace( + refresh=MagicMock(), + exists=MagicMock(return_value=False), + delete=MagicMock(), + create=MagicMock(), + ) + self.bulk = MagicMock(return_value={"errors": False, "items": []}) + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.delete_by_query = MagicMock() + self.get = MagicMock(return_value={"_id": "id"}) + self.exists = MagicMock(return_value=True) + + opensearch_helpers.BulkIndexError = BulkIndexError + opensearch_helpers.bulk = MagicMock() + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.helpers = opensearch_helpers + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearch_helpers, + } + + +@pytest.fixture +def lindorm_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.lindorm.lindorm_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.LindormVectorStoreConfig( + hosts="http://localhost:9200", + username="user", + password="pass", + using_ugc=False, + request_timeout=3.0, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("hosts", None, "config URL is required"), + ("username", None, "config USERNAME is required"), + ("password", None, "config PASSWORD is required"), + ], +) +def test_lindorm_config_validation(lindorm_module, field, value, message): + values = _config(lindorm_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + lindorm_module.LindormVectorStoreConfig.model_validate(values) + + +def test_to_opensearch_params_and_init(lindorm_module): + cfg = _config(lindorm_module) + params = cfg.to_opensearch_params() + + assert params["hosts"] == "http://localhost:9200" + assert params["http_auth"] == ("user", "pass") + + vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False) + assert vector._collection_name == "collection" + assert vector.get_type() == lindorm_module.VectorType.LINDORM + + with pytest.raises(ValueError, match="routing_value"): + lindorm_module.LindormVectorStore("c", cfg, using_ugc=True) + + vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE") + assert vector_ugc._routing == "route" + + +def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock()) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + + vector.add_texts(docs, embeddings, batch_size=2, timeout=9) + + assert vector._client.bulk.call_count == 2 + actions = vector._client.bulk.call_args_list[0].args[0] + assert actions[0]["index"]["routing"] == "route" + assert actions[1][lindorm_module.ROUTING_FIELD] == "route" + vector.refresh() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_add_texts_error_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]} + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + vector._client.bulk.side_effect = RuntimeError("bulk failed") + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + +def test_metadata_lookup_and_delete_by_metadata(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + query = vector._client.search.call_args.kwargs["body"] + must_conditions = query["query"]["bool"]["must"] + assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions) + + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"]) + + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-2") + vector.delete_by_ids.assert_not_called() + + +def test_delete_by_ids_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + vector.delete_by_ids([]) + vector._client.indices.exists.assert_not_called() + + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["id-1"]) + + vector._client.indices.exists.return_value = True + vector._client.exists.side_effect = [True, False] + lindorm_module.helpers.bulk.reset_mock() + vector.delete_by_ids(["id-1", "id-2"]) + lindorm_module.helpers.bulk.assert_called_once() + actions = lindorm_module.helpers.bulk.call_args.args[1] + assert len(actions) == 1 + assert actions[0]["routing"] == "route" + + lindorm_module.helpers.bulk.reset_mock() + lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError( + errors=[ + {"delete": {"status": 404, "_id": "id-404"}}, + {"delete": {"status": 500, "_id": "id-500"}}, + ] + ) + vector._client.exists.side_effect = [True] + vector.delete_by_ids(["id-1"]) + + +def test_delete_and_text_exists(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.delete() + vector._client.delete_by_query.assert_called_once() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.indices.exists.return_value = True + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60}) + + vector._client.indices.delete.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete() + vector._client.indices.delete.assert_not_called() + + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("missing") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validation_and_success(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + with pytest.raises(ValueError, match="should be a list"): + vector.search_by_vector("bad") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, "bad"]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-b", + lindorm_module.Field.VECTOR: [0.2], + lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + call_kwargs = vector._client.search.call_args.kwargs + query = call_kwargs["body"] + assert "ext" in query + assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"] + assert call_kwargs["params"]["routing"] == "route" + + vector._client.search.side_effect = RuntimeError("search failed") + with pytest.raises(RuntimeError, match="search failed"): + vector.search_by_vector([0.1]) + + +def test_search_by_full_text_success_and_error(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1"}, + } + } + ] + } + } + + docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] + + vector._client.search.side_effect = RuntimeError("full text failed") + with pytest.raises(RuntimeError, match="full text failed"): + vector.search_by_full_text("hello") + + +def test_create_collection_paths(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + + with pytest.raises(ValueError, match="cannot be empty"): + vector.create_collection([]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"}) + vector._client.indices.create.assert_called_once() + body = vector._client.indices.create.call_args.kwargs["body"] + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf" + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine" + + vector._client.indices.create.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + +def test_lindorm_factory_branches(lindorm_module, monkeypatch): + factory = lindorm_module.LindormVectorStoreFactory() + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2") + monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={}) + embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3]) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None) + with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"): + factory.init_vector(dataset, attributes=[], embeddings=embeddings) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + + dataset_existing_plain = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False}, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings) + assert result == "vector" + assert store_cls.call_args.args[0] == "existing" + + dataset_existing_ugc = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={ + "vector_store": {"class_prefix": "ROUTING"}, + "using_ugc": True, + "dimension": 1536, + "index_type": "hnsw", + "distance_type": "l2", + }, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "ROUTING" + + dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={}) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "auto_collection" + assert dataset_new.index_struct is not None + + dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={}) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "auto_collection" + assert store_cls.call_args.kwargs["routing_value"] is None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py new file mode 100644 index 0000000000..55e7b9112e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py @@ -0,0 +1,252 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_mo_vector_modules(): + mo_vector = types.ModuleType("mo_vector") + mo_vector.__path__ = [] + mo_vector_client = types.ModuleType("mo_vector.client") + + class MoVectorClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_full_text_index = MagicMock() + self.insert = MagicMock() + self.get = MagicMock(return_value=[]) + self.delete = MagicMock() + self.query_by_metadata = MagicMock(return_value=[]) + self.query = MagicMock(return_value=[]) + self.full_text_query = MagicMock(return_value=[]) + + mo_vector_client.MoVectorClient = MoVectorClient + mo_vector.client = mo_vector_client + return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client} + + +@pytest.fixture +def matrixone_module(monkeypatch): + for name, module in _build_fake_mo_vector_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.matrixone.matrixone_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.MatrixoneConfig( + host="localhost", + port=6001, + user="dump", + password="111", + database="dify", + metric="l2", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config host is required"), + ("port", 0, "config port is required"), + ("user", "", "config user is required"), + ("password", "", "config password is required"), + ("database", "", "config database is required"), + ], +) +def test_matrixone_config_validation(matrixone_module, field, value, message): + values = _valid_config(matrixone_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + matrixone_module.MatrixoneConfig.model_validate(values) + + +def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client.kwargs["table_name"] == "collection_1" + client.create_full_text_index.assert_called_once() + matrixone_module.redis_client.set.assert_called_once() + + +def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + client.create_full_text_index.assert_not_called() + matrixone_module.redis_client.set.assert_not_called() + + +def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = None + fake_client = MagicMock() + fake_client.get.return_value = [{"id": "seg-1"}] + vector._get_client = MagicMock(return_value=fake_client) + + exists = vector.text_exists("seg-1") + + assert exists is True + vector._get_client.assert_called_once_with(None, False) + + +def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.full_text_query.return_value = [ + SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}} + + +def test_get_type_and_create_delegate_to_add_texts(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + fake_client = MagicMock() + vector._get_client = MagicMock(return_value=fake_client) + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "matrixone" + assert result == ["seg-1"] + vector._get_client.assert_called_once_with(2, True) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + failing_client = MagicMock() + failing_client.create_full_text_index.side_effect = RuntimeError("boom") + monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client)) + + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client is failing_client + matrixone_module.redis_client.set.assert_not_called() + + +def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + # For current prod code, only docs with metadata get ids, so only two ids + assert ids == ["doc-a", "generated-uuid"] + vector.client.insert.assert_called_once() + insert_kwargs = vector.client.insert.call_args.kwargs + # All lists passed to insert should be the same length + texts = insert_kwargs["texts"] + embeddings = insert_kwargs["embeddings"] + metadatas = insert_kwargs["metadatas"] + ids_insert = insert_kwargs["ids"] + assert len(texts) == len(embeddings) == len(metadatas) == len(docs) + # ids may be shorter than docs for current prod code, but should match number of docs with metadata + assert ids_insert == ["doc-a", "generated-uuid"] + + +def test_delete_and_metadata_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")] + + vector.delete_by_ids([]) + vector.client.delete.assert_not_called() + + vector.delete_by_ids(["seg-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + vector.delete() + + assert ids == ["seg-1", "seg-2"] + assert vector.client.delete.call_count == 3 + + +def test_search_by_vector_builds_documents(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query.return_value = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}), + ] + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"]) + + assert len(docs) == 2 + assert docs[0].page_content == "doc-a" + assert docs[1].metadata["doc_id"] == "2" + assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}} + + +def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch): + factory = matrixone_module.MatrixoneVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001) + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2") + + with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index fb2ddfe162..2ac2c40d38 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -1,18 +1,414 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + import pytest from pydantic import ValidationError -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig +from core.rag.models.document import Document -def test_default_value(): +def _build_fake_pymilvus_modules(): + pymilvus = types.ModuleType("pymilvus") + pymilvus.__path__ = [] + pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client") + pymilvus_orm = types.ModuleType("pymilvus.orm") + pymilvus_orm.__path__ = [] + pymilvus_orm_types = types.ModuleType("pymilvus.orm.types") + + class MilvusError(Exception): + pass + + class MilvusClient: + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.has_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock( + return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]} + ) + self.get_server_version = MagicMock(return_value="2.5.0") + self.insert = MagicMock(return_value=[1]) + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.drop_collection = MagicMock() + self.search = MagicMock(return_value=[[]]) + self.create_collection = MagicMock() + + class IndexParams: + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + class DataType: + JSON = "JSON" + VARCHAR = "VARCHAR" + INT64 = "INT64" + SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR" + FLOAT_VECTOR = "FLOAT_VECTOR" + + class FieldSchema: + def __init__(self, name, dtype, **kwargs): + self.name = name + self.dtype = dtype + self.kwargs = kwargs + + class CollectionSchema: + def __init__(self, fields): + self.fields = fields + self.functions = [] + + def add_function(self, func): + self.functions.append(func) + + class FunctionType: + BM25 = "BM25" + + class Function: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def infer_dtype_bydata(_value): + return DataType.FLOAT_VECTOR + + pymilvus.MilvusException = MilvusError + pymilvus.MilvusClient = MilvusClient + pymilvus.IndexParams = IndexParams + pymilvus.CollectionSchema = CollectionSchema + pymilvus.DataType = DataType + pymilvus.FieldSchema = FieldSchema + pymilvus.Function = Function + pymilvus.FunctionType = FunctionType + pymilvus_milvus_client.IndexParams = IndexParams + pymilvus_orm.types = pymilvus_orm_types + pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata + + # Attach submodules for dotted imports + pymilvus.milvus_client = pymilvus_milvus_client + pymilvus.orm = pymilvus_orm + + return { + "pymilvus": pymilvus, + "pymilvus.milvus_client": pymilvus_milvus_client, + "pymilvus.orm": pymilvus_orm, + "pymilvus.orm.types": pymilvus_orm_types, + } + + +@pytest.fixture +def milvus_module(monkeypatch): + for name, module in _build_fake_pymilvus_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.milvus.milvus_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "uri": "http://localhost:19530", + "user": "root", + "password": "Milvus", + "database": "default", + "enable_hybrid_search": False, + "analyzer_params": None, + } + values.update(overrides) + return module.MilvusConfig.model_validate(values) + + +def test_config_validation_and_defaults(milvus_module): valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig.model_validate(config) + milvus_module.MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig.model_validate(valid_config) + config = milvus_module.MilvusConfig.model_validate(valid_config) assert config.database == "default" + + token_config = milvus_module.MilvusConfig.model_validate( + {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"} + ) + assert token_config.token == "token-value" + + +def test_config_to_milvus_params(milvus_module): + config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + + params = config.to_milvus_params() + + assert params["uri"] == "http://localhost:19530" + assert params["db_name"] == "default" + assert params["analyzer_params"] == '{"tokenizer":"standard"}' + + +def test_init_client_supports_token_and_user_password(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + token_client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"} + + user_client = vector._init_client(_config(milvus_module)) + assert user_client.init_kwargs["uri"] == "http://localhost:19530" + assert user_client.init_kwargs["user"] == "root" + assert user_client.init_kwargs["password"] == "Milvus" + + +def test_init_loads_fields_when_collection_exists(milvus_module): + client = milvus_module.MilvusClient(uri="http://localhost:19530") + client.has_collection.return_value = True + client.describe_collection.return_value = { + "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}] + } + + with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client): + with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False): + vector = milvus_module.MilvusVector("collection_1", _config(milvus_module)) + + assert "id" not in vector._fields + assert "content" in vector._fields + + +def test_load_collection_fields_from_argument_and_remote(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + vector._collection_name = "collection_1" + vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]} + + vector._load_collection_fields(["id", "metadata"]) + assert vector._fields == ["metadata"] + + vector._load_collection_fields() + assert vector._fields == ["content"] + + +def test_check_hybrid_search_support_branches(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + + vector._client_config = SimpleNamespace(enable_hybrid_search=False) + assert vector._check_hybrid_search_support() is False + + vector._client_config = SimpleNamespace(enable_hybrid_search=True) + vector._client.get_server_version.return_value = "Zilliz Cloud 2.4" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.5.1" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.4.9" + assert vector._check_hybrid_search_support() is False + + vector._client.get_server_version.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_get_type_and_create_delegate(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [SimpleNamespace(page_content="hello", metadata=None)] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "milvus" + vector.create_collection.assert_called_once() + create_args = vector.create_collection.call_args.args + assert create_args[0] == [[0.1, 0.2]] + assert create_args[1] == [{}] + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_and_raises_milvus_exception(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.insert.side_effect = [["id-1"], ["id-2"]] + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + ids = vector.add_texts(docs, embeddings) + assert ids == ["id-1", "id-2"] + assert vector._client.insert.call_count == 2 + + vector._client.insert.side_effect = milvus_module.MilvusException("insert failed") + with pytest.raises(milvus_module.MilvusException): + vector.add_texts([Document(page_content="x", metadata={})], [[0.1]]) + + +def test_get_ids_and_delete_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.query.return_value = [{"id": 1}, {"id": 2}] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2] + vector._client.query.return_value = [] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector._client.has_collection.return_value = True + vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102]) + + vector._client.delete.reset_mock() + vector._client.query.return_value = [{"id": 11}, {"id": 12}] + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12]) + + vector._client.has_collection.return_value = True + vector.delete() + vector._client.drop_collection.assert_called_once_with("collection_1", None) + + +def test_text_exists_and_field_exists(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._fields = ["content", "metadata"] + vector._client = MagicMock() + vector._client.has_collection.return_value = False + assert vector.text_exists("doc-1") is False + + vector._client.has_collection.return_value = True + vector._client.query.return_value = [{"id": 1}] + assert vector.text_exists("doc-1") is True + vector._client.query.return_value = [] + assert vector.text_exists("doc-1") is False + assert vector.field_exists("content") is True + assert vector.field_exists("unknown") is False + + +def test_process_search_results_and_search_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._fields = ["content", "metadata", "sparse_vector"] + + processed = vector._process_search_results( + [ + [ + {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9}, + {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2}, + ] + ], + [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY], + score_threshold=0.5, + ) + assert len(processed) == 1 + assert processed[0].metadata["score"] == 0.9 + + vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1) + assert docs == ["doc"] + assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]' + + vector._hybrid_search_enabled = False + assert vector.search_by_full_text("query") == [] + + vector._hybrid_search_enabled = True + vector._fields = [] + assert vector.search_by_full_text("query") == [] + + vector._fields = [milvus_module.Field.SPARSE_VECTOR] + vector._process_search_results = MagicMock(return_value=["full-text-doc"]) + full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2) + assert full_text_docs == ["full-text-doc"] + assert "document_id" in vector._client.search.call_args.kwargs["filter"] + + +def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client_config = _config(milvus_module) + vector._hybrid_search_enabled = False + vector._client = MagicMock() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.has_collection.return_value = True + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + milvus_module.redis_client.set.assert_called() + + +def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client = MagicMock() + vector._client.has_collection.return_value = False + vector._load_collection_fields = MagicMock() + + vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + vector._hybrid_search_enabled = True + vector.create_collection( + embeddings=[[0.1, 0.2]], + metadatas=[{"doc_id": "1"}], + index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}}, + ) + + call_kwargs = vector._client.create_collection.call_args.kwargs + schema = call_kwargs["schema"] + index_params_obj = call_kwargs["index_params"] + field_names = [f.name for f in schema.fields] + + assert milvus_module.Field.SPARSE_VECTOR in field_names + assert len(schema.functions) == 1 + assert len(index_params_obj.indexes) == 2 + assert call_kwargs["consistency_level"] == "Session" + + +def test_factory_initializes_milvus_vector(milvus_module, monkeypatch): + factory = milvus_module.MilvusVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}') + + with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py new file mode 100644 index 0000000000..a75ba82238 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py @@ -0,0 +1,230 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickhouse_connect_module(): + clickhouse_connect = types.ModuleType("clickhouse_connect") + + class QueryResult: + def __init__(self, rows=None, named_rows=None): + self.row_count = len(rows or []) + self.result_rows = rows or [] + self._named_rows = named_rows or [] + + def named_results(self): + return self._named_rows + + class Client: + def __init__(self): + self.command = MagicMock() + self.query = MagicMock(return_value=QueryResult()) + + client = Client() + + def get_client(**_kwargs): + return client + + clickhouse_connect.get_client = get_client + clickhouse_connect.QueryResult = QueryResult + clickhouse_connect._fake_client = client + return clickhouse_connect + + +@pytest.fixture +def myscale_module(monkeypatch): + fake_module = _build_fake_clickhouse_connect_module() + monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) + + import core.rag.datasource.vdb.myscale.myscale_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ) + + +def test_escape_str_replaces_backslash_and_quote(myscale_module): + escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special") + assert escaped == "text with special" + + +def test_search_raises_for_invalid_top_k(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0) + + +def test_search_builds_where_clause_for_cosine_threshold(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__( + named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}] + ) + + docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2) + + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "WHERE dist < 0.8" in sql + + +def test_delete_by_ids_short_circuits_on_empty_list(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete_by_ids([]) + vector._client.command.assert_not_called() + + +def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch): + factory = myscale_module.MyScaleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123) + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "") + + with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +def test_init_and_get_type_set_expected_defaults(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + + assert vector.get_type() == "myscale" + assert vector._vec_order == myscale_module.SortOrder.ASC + vector._client.command.assert_called_with("SET allow_experimental_object_type=1") + + +def test_create_calls_create_collection_and_add_texts(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once() + + +def test_create_collection_builds_expected_sql(myscale_module): + config = myscale_module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="tokenizer=unicode", + ) + vector = myscale_module.MyScaleVector("collection_1", config) + vector._client.command.reset_mock() + + vector._create_collection(3) + + assert vector._client.command.call_count == 2 + sql = vector._client.command.call_args_list[1].args[0] + assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql + assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql + assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql + + +def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="text-2", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="text-3", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + sql = vector._client.command.call_args.args[0] + assert "INSERT INTO dify.collection_1" in sql + assert "te xt 1" in sql + + +def test_text_exists_and_metadata_operations(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)]) + + assert vector.text_exists("id-1") is True + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._client.command.call_count >= 2 + + +def test_search_delegation_methods(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._search = MagicMock(return_value=["result"]) + + result_vector = vector.search_by_vector([0.1, 0.2], top_k=2) + result_text = vector.search_by_full_text("hello", top_k=2) + + assert result_vector == ["result"] + assert result_text == ["result"] + assert vector._search.call_count == 2 + + +def test_search_with_document_filter_and_exception(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace( + named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}] + ) + + docs = vector._search( + "distance(vector, [0.1])", + myscale_module.SortOrder.ASC, + top_k=2, + document_ids_filter=["doc-1", "doc-2"], + ) + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql + + vector._client.query.side_effect = RuntimeError("boom") + assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == [] + + +def test_delete_drops_table(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete() + + vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py new file mode 100644 index 0000000000..27d8198ec0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py @@ -0,0 +1,553 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.exc import SQLAlchemyError + +from core.rag.models.document import Document + + +def _build_fake_pyobvector_module(): + pyobvector = types.ModuleType("pyobvector") + + class VECTOR: + def __init__(self, dim): + self.dim = dim + + def l2_distance(*_args, **_kwargs): + return "l2" + + def cosine_distance(*_args, **_kwargs): + return "cosine" + + def inner_product(*_args, **_kwargs): + return "inner_product" + + class ObVecClient: + def __init__(self, **_kwargs): + self.metadata_obj = SimpleNamespace(tables={}) + self.engine = MagicMock() + self.check_table_exists = MagicMock(return_value=False) + self.perform_raw_text_sql = MagicMock() + self.prepare_index_params = MagicMock() + self.create_table_with_index_params = MagicMock() + self.refresh_metadata = MagicMock() + self.insert = MagicMock() + self.refresh_index = MagicMock() + self.get = MagicMock() + self.delete = MagicMock() + self.set_ob_hnsw_ef_search = MagicMock() + self.ann_search = MagicMock(return_value=[]) + self.drop_table_if_exist = MagicMock() + + pyobvector.VECTOR = VECTOR + pyobvector.ObVecClient = ObVecClient + pyobvector.l2_distance = l2_distance + pyobvector.cosine_distance = cosine_distance + pyobvector.inner_product = inner_product + return pyobvector + + +@pytest.fixture +def oceanbase_module(monkeypatch): + monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) + + import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.OceanBaseVectorConfig( + host="127.0.0.1", + port=2881, + user="root", + password="secret", + database="test", + enable_hybrid_search=True, + batch_size=10, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OCEANBASE_VECTOR_HOST is required"), + ("port", 0, "config OCEANBASE_VECTOR_PORT is required"), + ("user", "", "config OCEANBASE_VECTOR_USER is required"), + ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"), + ], +) +def test_oceanbase_config_validation(oceanbase_module, field, value, message): + values = _config(oceanbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oceanbase_module.OceanBaseVectorConfig.model_validate(values) + + +def test_init_rejects_invalid_collection_name(oceanbase_module): + with pytest.raises(ValueError, match="Invalid collection name"): + oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module)) + + +def test_distance_to_score_for_supported_metrics(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="l2") + assert vector._distance_to_score(3.0) == pytest.approx(0.25) + + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._distance_to_score(0.2) == pytest.approx(0.8) + + vector._config = SimpleNamespace(metric_type="inner_product") + assert vector._distance_to_score(-0.2) == pytest.approx(0.2) + + +def test_get_distance_func_raises_for_unknown_metric(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="manhattan") + + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._get_distance_func() + + +def test_process_search_results_handles_json_and_score_threshold(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + rows = [ + ("doc-1", '{"doc_id":"1"}', 0.9), + ("doc-2", "not-json", 0.8), + ("doc-3", {"doc_id": "3"}, 0.3), + ] + + docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank") + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["rank"] == 0.9 + assert docs[1].metadata["rank"] == 0.8 + + +def test_search_by_vector_validates_document_id_format(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + + with pytest.raises(ValueError, match="Invalid document ID format"): + vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"]) + + +def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._hybrid_search_enabled = False + vector._collection_name = "collection_1" + + assert vector.search_by_full_text("query") == [] + + +def test_check_hybrid_search_support_uses_version_comment(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client = MagicMock() + cursor = MagicMock() + cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",) + vector._client.perform_raw_text_sql.return_value = cursor + + assert vector._check_hybrid_search_support() is True + + cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",) + assert vector._check_hybrid_search_support() is False + + +def test_init_get_type_and_field_loading(oceanbase_module): + config = _config(oceanbase_module) + config.enable_hybrid_search = False + + table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")]) + fake_client = oceanbase_module.ObVecClient() + fake_client.check_table_exists.return_value = True + fake_client.metadata_obj.tables = {"collection_1": table} + + with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client): + vector = oceanbase_module.OceanBaseVector("collection_1", config) + + assert vector.get_type() == "oceanbase" + assert vector.field_exists("text") is True + + +def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._fields = [] + vector._client = MagicMock() + vector._client.metadata_obj.tables = {} + + vector._load_collection_fields() + assert vector._fields == [] + + vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))} + vector._load_collection_fields() + assert vector._fields == [] + + +def test_create_delegates_to_collection_and_insert(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector._vec_dim == 2 + vector._create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = False + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection() + vector._client.check_table_exists.assert_not_called() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.check_table_exists.return_value = True + vector._create_collection() + vector.delete.assert_not_called() + + +def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 3 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + None, + None, + ] + index_params = MagicMock() + vector._client.prepare_index_params.return_value = index_params + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._create_collection() + + vector.delete.assert_called_once() + vector._client.create_table_with_index_params.assert_called_once() + index_params.add_index.assert_called_once() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + oceanbase_module.redis_client.set.assert_called_once() + + +def test_create_collection_error_paths(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.return_value = [] + with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "0"]], + RuntimeError("no privilege"), + ] + with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]] + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid") + with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"): + vector._create_collection() + + +def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + RuntimeError("fulltext failed"), + ] + with pytest.raises(Exception, match="Failed to add fulltext index"): + vector._create_collection() + + vector._hybrid_search_enabled = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + SQLAlchemyError("metadata index failed"), + ] + vector._create_collection() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + + +def test_check_hybrid_search_support_false_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=False) + vector._client = MagicMock() + assert vector._check_hybrid_search_support() is False + + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_add_texts_batches_refresh_and_exceptions(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2) + vector._client = MagicMock() + vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + + vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert vector._client.insert.call_count == 2 + vector._client.refresh_index.assert_called_once() + + vector._client.insert.reset_mock() + vector._client.refresh_index.reset_mock() + vector._client.insert.side_effect = RuntimeError("insert failed") + with pytest.raises(Exception, match="Failed to insert batch"): + vector.add_texts([docs[0]], [[0.1]]) + + vector._client.insert.side_effect = None + vector._client.insert.return_value = None + vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed") + vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1) + vector._get_uuids.return_value = ["id-1"] + vector.add_texts([docs[0]], [[0.1]]) + + +def test_text_exists_and_delete_by_ids(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.get.return_value = SimpleNamespace(rowcount=1) + assert vector.text_exists("id-1") is True + + vector._client.get.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to check text existence"): + vector.text_exists("id-1") + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + + vector._client.delete.side_effect = None + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once() + + vector._client.delete.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete documents"): + vector.delete_by_ids(["id-1"]) + + +def test_get_ids_and_delete_by_metadata_field(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + execute_result = [("id-1",), ("id-2",)] + + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value = execute_result + vector._client.engine.connect.return_value = conn + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + + with pytest.raises(Exception, match="Failed to query documents by metadata field"): + vector.get_ids_by_metadata_field("bad key!", "doc-1") + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_not_called() + + +def test_search_by_full_text_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hybrid_search_enabled = True + vector.field_exists = MagicMock(return_value=False) + + assert vector.search_by_full_text("query") == [] + + vector.field_exists.return_value = True + vector._client = MagicMock() + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = tx + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)] + vector._client.engine.connect.return_value = conn + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + with pytest.raises(Exception, match="Full-text search failed"): + vector.search_by_full_text("query", top_k=0) + + +def test_search_by_vector_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector( + [0.1, 0.2], + ef_search=10, + top_k=3, + score_threshold=0.1, + document_ids_filter=["good_id"], + ) + assert docs == ["doc"] + vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10) + + with pytest.raises(ValueError, match="Invalid score_threshold parameter"): + vector.search_by_vector([0.1], score_threshold="x") + + vector._client.ann_search.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Vector search failed"): + vector.search_by_vector([0.1], score_threshold=0.1) + + +def test_get_distance_func_and_distance_to_score_errors(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._get_distance_func() is oceanbase_module.cosine_distance + + vector._config = SimpleNamespace(metric_type="unknown") + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._distance_to_score(0.1) + + +def test_delete_success_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector.delete() + vector._client.drop_table_if_exist.assert_called_once_with("collection_1") + + vector._client.drop_table_if_exist.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete collection"): + vector.delete() + + +def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch): + factory = oceanbase_module.OceanBaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000) + + with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].args[0] == "existing_collection" + assert vector_cls.call_args_list[1].args[0] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py new file mode 100644 index 0000000000..6641dbe4a0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py @@ -0,0 +1,400 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def opengauss_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opengauss.opengauss as module + + return importlib.reload(module) + + +def _config(module, *, enable_pq=False): + return module.OpenGaussConfig( + host="localhost", + port=6600, + user="postgres", + password="password", + database="dify", + min_connection=1, + max_connection=5, + enable_pq=enable_pq, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENGAUSS_HOST is required"), + ("port", 0, "config OPENGAUSS_PORT is required"), + ("user", "", "config OPENGAUSS_USER is required"), + ("password", "", "config OPENGAUSS_PASSWORD is required"), + ("database", "", "config OPENGAUSS_DATABASE is required"), + ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"), + ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"), + ], +) +def test_opengauss_config_validation(opengauss_module, field, value, message): + values = _config(opengauss_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module): + values = _config(opengauss_module).model_dump() + values["min_connection"] = 6 + values["max_connection"] = 5 + + with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + + assert vector.table_name == "embedding_collection_1" + assert vector.get_type() == "opengauss" + assert vector.pool is pool + + +def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("enable_pq=on" in sql for sql in executed_sql) + assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql) + opengauss_module.redis_client.set.assert_called_once() + + +def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(3072) + + cursor.execute.assert_not_called() + opengauss_module.redis_client.set.assert_called_once() + + +def test_search_by_vector_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + +def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector.delete_by_ids([]) + + vector._get_cursor.assert_not_called() + + +def test_get_cursor_closes_commits_and_returns_connection(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + pool = MagicMock() + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + vector.pool = pool + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_calls_collection_insert_and_index(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + vector._create_index = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + vector._create_index.assert_called_once_with(2) + + +def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector._create_index(1536) + + vector._get_cursor.assert_not_called() + opengauss_module.redis_client.set.assert_not_called() + + +def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql) + + +def test_add_texts_uses_execute_values(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + opengauss_module.psycopg2.extras.execute_values.reset_mock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + docs = [ + Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}), + SimpleNamespace(page_content="text-2", metadata=None), + ] + monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid") + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["seg-1"] + opengauss_module.psycopg2.extras.execute_values.assert_called_once() + + +def test_text_exists_and_get_by_ids(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("seg-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("seg-1") is True + docs = vector.get_by_ids(["seg-1", "seg-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + +def test_delete_and_metadata_field_queries(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + vector.delete_by_ids(["seg-1", "seg-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql) + assert any("meta->>%s = %s" in query for query in sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql) + + +def test_search_by_vector_and_full_text(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.6), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_search_by_full_text_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=0) + + +def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(1536) + cursor.execute.assert_not_called() + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(1536) + cursor.execute.assert_called_once() + opengauss_module.redis_client.set.assert_called_once() + + +def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch): + factory = opengauss_module.OpenGaussFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False) + + with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py new file mode 100644 index 0000000000..1030158dd1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py @@ -0,0 +1,360 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "basic", + "user": "admin", + "password": "secret", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENSEARCH_HOST is required"), + ("port", 0, "config OPENSEARCH_PORT is required"), + ], +) +def test_config_validation_required_fields(opensearch_module, field, value, message): + values = _config(opensearch_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_config_validation_for_aws_auth_and_https_fields(opensearch_module): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "aws_managed_iam", + "user": "admin", + "password": "secret", + } + with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"): + opensearch_module.OpenSearchConfig.model_validate(values) + + values = _config(opensearch_module).model_dump() + values["OPENSEARCH_SECURE"] = False + values["OPENSEARCH_VERIFY_CERTS"] = True + with pytest.raises(ValidationError, match="verify_certs=True requires secure"): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-east-1", + aws_service="es", + ) + auth = config.create_aws_managed_iam_auth() + + assert auth.credentials == "creds" + assert auth.region == "us-east-1" + assert auth.service == "es" + + +def test_to_opensearch_params_supports_basic_and_aws(opensearch_module): + basic_params = _config(opensearch_module).to_opensearch_params() + assert basic_params["http_auth"] == ("admin", "secret") + + aws_config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-west-2", + aws_service="es", + ) + with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"): + aws_params = aws_config.to_opensearch_params() + + assert aws_params["http_auth"] == "iam-auth" + + +def test_init_and_create_delegate_calls(opensearch_module): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "opensearch" + vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es")) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id")) + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.1], [0.2]]) + actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert len(actions) == 2 + assert all("_id" in action for action in actions) + + vector._client_config.aws_service = "aoss" + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.3], [0.4]]) + aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert all("_id" not in action for action in aoss_actions) + + +def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector._client.search.return_value = {"hits": {"hits": []}} + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + +def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + opensearch_module.helpers.bulk.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["doc-1"]) + opensearch_module.helpers.bulk.assert_not_called() + + vector._client.indices.exists.return_value = True + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None]) + vector.delete_by_ids(["doc-1", "doc-2"]) + opensearch_module.helpers.bulk.assert_called_once() + + opensearch_module.helpers.bulk.reset_mock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"]) + opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError( + [{"delete": {"status": 404, "_id": "es-404"}}] + ) + vector.delete_by_ids(["doc-404"]) + assert opensearch_module.helpers.bulk.call_count == 1 + + opensearch_module.helpers.bulk.side_effect = None + + +def test_delete_and_text_exists(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True) + + vector._client.get.return_value = {"_id": "id-1"} + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("not found") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validates_and_builds_documents(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + + with pytest.raises(ValueError, match="query_vector should be a list"): + vector.search_by_vector("not-a-list") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, 1]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-1", + opensearch_module.Field.METADATA_KEY: None, + }, + "_score": 0.9, + }, + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-2", + opensearch_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + "_score": 0.1, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"]) + query = vector._client.search.call_args.kwargs["body"] + assert "script_score" in query["query"] + + +def test_search_by_vector_reraises_client_error(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.search_by_vector([0.1, 0.2]) + + +def test_search_by_full_text_and_filters(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.METADATA_KEY: {"doc_id": "1"}, + opensearch_module.Field.VECTOR: [0.1], + opensearch_module.Field.CONTENT_KEY: "matched text", + } + }, + ] + } + } + + docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "matched text" + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}] + + +def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock()) + + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1)) + vector._client.indices.create.reset_mock() + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_called_once() + index_body = vector._client.indices.create.call_args.kwargs["body"] + assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2 + opensearch_module.redis_client.set.assert_called() + + +def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch): + factory = opensearch_module.OpenSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None) + + with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py new file mode 100644 index 0000000000..817a7d342b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py @@ -0,0 +1,375 @@ +import array +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_oracle_modules(): + jieba = types.ModuleType("jieba") + jieba_posseg = types.ModuleType("jieba.posseg") + jieba_posseg.cut = MagicMock(return_value=[]) + jieba.posseg = jieba_posseg + + oracledb = types.ModuleType("oracledb") + oracledb_connection = types.ModuleType("oracledb.connection") + + class Connection: + pass + + oracledb_connection.Connection = Connection + oracledb.defaults = SimpleNamespace(fetch_lobs=True) + oracledb.DB_TYPE_VECTOR = object() + oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock())) + oracledb.connect = MagicMock() + + return { + "jieba": jieba, + "jieba.posseg": jieba_posseg, + "oracledb": oracledb, + "oracledb.connection": oracledb_connection, + } + + +def _connection_with_cursor(cursor): + cursor_ctx = MagicMock() + cursor_ctx.__enter__.return_value = cursor + cursor_ctx.__exit__.return_value = None + + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = None + connection.cursor.return_value = cursor_ctx + return connection + + +@pytest.fixture +def oracle_module(monkeypatch): + for name, module in _build_fake_oracle_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.oracle.oraclevector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "user": "system", + "password": "oracle", + "dsn": "oracle:1521/freepdb1", + "is_autonomous": False, + } + values.update(overrides) + return module.OracleVectorConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("user", "", "config ORACLE_USER is required"), + ("password", "", "config ORACLE_PASSWORD is required"), + ("dsn", "", "config ORACLE_DSN is required"), + ], +) +def test_oracle_config_validation_required_fields(oracle_module, field, value, message): + values = _config(oracle_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oracle_module.OracleVectorConfig.model_validate(values) + + +def test_oracle_config_validation_autonomous_requirements(oracle_module): + with pytest.raises(ValidationError, match="config_dir is required"): + oracle_module.OracleVectorConfig.model_validate( + {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True} + ) + + +def test_init_and_get_type(oracle_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool)) + vector = oracle_module.OracleVector("collection_1", _config(oracle_module)) + + assert vector.get_type() == "oracle" + assert vector.table_name == "embedding_collection_1" + assert vector.pool is pool + + +def test_numpy_converters_and_type_handlers(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + + in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64)) + in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32)) + in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8)) + assert in_float64.typecode == "d" + assert in_float32.typecode == "f" + assert in_int8.typecode == "b" + + cursor = MagicMock() + vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2) + cursor.var.assert_called_with( + oracle_module.oracledb.DB_TYPE_VECTOR, + arraysize=2, + inconverter=vector.numpy_converter_in, + ) + + metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR) + cursor.arraysize = 3 + vector.output_type_handler(cursor, metadata) + cursor.var.assert_called_with( + metadata.type_code, + arraysize=3, + outconverter=vector.numpy_converter_out, + ) + + out_int8 = vector.numpy_converter_out(array.array("b", [1])) + assert out_int8.dtype == numpy.int8 + out_float32 = vector.numpy_converter_out(array.array("f", [1.0])) + assert out_float32.dtype == numpy.float32 + out_float64 = vector.numpy_converter_out(array.array("d", [1.0])) + assert out_float64.dtype == numpy.float64 + + +def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch): + connect = MagicMock(return_value="connection") + monkeypatch.setattr(oracle_module.oracledb, "connect", connect) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.config = _config(oracle_module) + assert vector._get_connection() == "connection" + connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1") + + vector.config = _config( + oracle_module, + is_autonomous=True, + config_dir="/wallet", + wallet_location="/wallet", + wallet_password="pw", + ) + vector._get_connection() + assert connect.call_args.kwargs["config_dir"] == "/wallet" + assert connect.call_args.kwargs["wallet_location"] == "/wallet" + + +def test_create_delegates_collection_and_insert(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.execute.side_effect = [None, RuntimeError("insert failed")] + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + assert cursor.execute.call_count == 2 + assert connection.commit.call_count >= 1 + connection.close.assert_called() + + +def test_text_exists_and_get_by_ids(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.pool = MagicMock() + + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + vector.pool.release.assert_called_once() + assert vector.get_by_ids([]) == [] + + +def test_delete_methods(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + vector.delete_by_ids([]) + vector._get_connection.assert_not_called() + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql) + assert any("JSON_VALUE(meta" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_with_threshold_and_filter(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)]) + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=0, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "fetch first 4 rows only" in sql + assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql + + +def _fake_nltk_module(*, missing_data=False): + nltk = types.ModuleType("nltk") + nltk_corpus = types.ModuleType("nltk.corpus") + + class _Data: + @staticmethod + def find(_path): + if missing_data: + raise LookupError("missing") + return True + + nltk.data = _Data() + nltk.word_tokenize = lambda text: text.split() + nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"]) + return nltk, nltk_corpus + + +def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")])) + zh_docs = vector.search_by_full_text("张三", top_k=2) + assert len(zh_docs) == 1 + zh_params = cursor.execute.call_args.args[1] + assert zh_params["kk"] == "张三" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=False) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])]) + en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"]) + assert len(en_docs) == 1 + en_sql = cursor.execute.call_args.args[0] + en_params = cursor.execute.call_args.args[1] + assert "fetch first 5 rows only" in en_sql + assert "doc_id_0" in en_params + + +def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector._get_connection = MagicMock() + + empty_result = vector.search_by_full_text("") + assert empty_result[0].page_content == "" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=True) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + with pytest.raises(LookupError, match="required NLTK data package"): + vector.search_by_full_text("english query") + + +def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock()) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_not_called() + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(2) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql) + oracle_module.redis_client.set.assert_called_once() + + +def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch): + factory = oracle_module.OracleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False) + + with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000..1aec81b8ac --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,317 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_pgvecto_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSessionContext: + def __init__(self, calls, execute_results=None): + self.calls = calls + self.execute_results = execute_results or [] + self.execute = MagicMock(side_effect=self._execute_side_effect) + self.commit = MagicMock() + + def _execute_side_effect(self, *args, **kwargs): + self.calls.append((args, kwargs)) + if self.execute_results: + return self.execute_results.pop(0) + return MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(calls, execute_results=None): + def _session(_client): + return _FakeSessionContext(calls=calls, execute_results=execute_results) + + return _session + + +@pytest.fixture +def pgvecto_module(monkeypatch): + for name, module in _build_fake_pgvecto_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module + import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + + return importlib.reload(module), importlib.reload(collection_module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "postgres", + } + values.update(overrides) + return module.PgvectoRSConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config PGVECTO_RS_HOST is required"), + ("port", 0, "config PGVECTO_RS_PORT is required"), + ("user", "", "config PGVECTO_RS_USER is required"), + ("password", "", "config PGVECTO_RS_PASSWORD is required"), + ("database", "", "config PGVECTO_RS_DATABASE is required"), + ], +) +def test_pgvecto_config_validation(pgvecto_module, field, value, message): + module, _ = pgvecto_module + values = _config(module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + module.PgvectoRSConfig.model_validate(values) + + +def test_collection_base_has_expected_annotations(pgvecto_module): + _, collection_module = pgvecto_module + annotations = collection_module.CollectionORM.__annotations__ + assert {"id", "text", "meta", "vector"} <= set(annotations) + + +def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == module.VectorType.PGVECTO_RS + module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres") + assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls) + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(module.redis_client, "set", MagicMock()) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection(3) + assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None)) + vector.create_collection(3) + assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls) + module.redis_client.set.assert_called() + + +def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + runtime_calls = [] + execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] + + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + + class _InsertBuilder: + def __init__(self, table): + self.table = table + + def values(self, **kwargs): + return ("insert", kwargs) + + monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table)) + monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["uuid-1", "uuid-2"] + assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0]) + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("row-id-1",)]), + MagicMock(), + ], + ), + ) + vector.delete_by_ids(["doc-1"]) + assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + vector.delete() + assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) + + +def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + runtime_calls = [] + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("id-1",)]), + SimpleNamespace(fetchall=lambda: []), + ], + ), + ) + assert vector.text_exists("doc-1") is True + assert vector.text_exists("doc-1") is False + + class _DistanceExpr: + def label(self, _name): + return self + + class _VectorColumn: + def op(self, _operator, return_type=None): + def _call(_query_vector): + return _DistanceExpr() + + return _call + + class _MetaFilter: + def in_(self, values): + return ("in", values) + + class _MetaColumn: + def __getitem__(self, _item): + return _MetaFilter() + + class _Stmt: + def __init__(self): + self.where_called = False + + def limit(self, _value): + return self + + def order_by(self, _value): + return self + + def where(self, _value): + self.where_called = True + return self + + stmt = _Stmt() + monkeypatch.setattr(module, "select", lambda *_args: stmt) + + vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn()) + rows = [ + (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), + (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), + ] + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert stmt.where_called is True + assert vector.search_by_full_text("hello") == [] + + +def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + factory = module.PGVectoRSFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432) + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres") + + embeddings = MagicMock() + embeddings.embed_query.return_value = [0.1, 0.2, 0.3] + + with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py index 4998a9858f..7505262eb7 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -1,16 +1,19 @@ -import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module from core.rag.datasource.vdb.pgvector.pgvector import ( PGVector, PGVectorConfig, ) +from core.rag.models.document import Document -class TestPGVector(unittest.TestCase): - def setUp(self): +class TestPGVector: + def setup_method(self, method): self.config = PGVectorConfig( host="localhost", port=5432, @@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override): PGVectorConfig(**config) -if __name__ == "__main__": - unittest.main() +def test_create_delegates_collection_creation_and_insert(): + vector = PGVector.__new__(PGVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["doc-a"]) + docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["doc-a"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid") + execute_values = MagicMock() + monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values) + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + execute_values.assert_called_once() + + +def test_text_get_and_delete_methods(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + class _UndefinedTableError(Exception): + pass + + monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError) + cursor.execute.side_effect = _UndefinedTableError("missing") + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + vector.delete_by_ids(["doc-1"]) + + +def test_search_by_vector_supports_filter_and_threshold(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "meta->>'document_id' in ('d-1')" in sql + + +def test_search_by_full_text_branches_for_bigm_and_standard(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + vector.pg_bigm = False + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.7) + standard_sql = cursor.execute.call_args.args[0] + assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql + + cursor.execute.reset_mock() + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)]) + vector.pg_bigm = True + vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"]) + assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0] + assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0] + + +def test_pgvector_factory_initializes_expected_collection_name(monkeypatch): + factory = pgvector_module.PGVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False) + + with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py new file mode 100644 index 0000000000..bd8df520ba --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py @@ -0,0 +1,269 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def vastbase_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VastbaseVectorConfig( + host="localhost", + port=5432, + user="dify", + password="secret", + database="dify", + min_connection=1, + max_connection=5, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config VASTBASE_HOST is required"), + ("port", 0, "config VASTBASE_PORT is required"), + ("user", "", "config VASTBASE_USER is required"), + ("password", "", "config VASTBASE_PASSWORD is required"), + ("database", "", "config VASTBASE_DATABASE is required"), + ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"), + ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"), + ], +) +def test_vastbase_config_validation(vastbase_module, field, value, message): + values = _config(vastbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + vastbase_module.VastbaseVectorConfig.model_validate(values) + + +def test_vastbase_config_rejects_invalid_connection_window(vastbase_module): + with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"): + vastbase_module.VastbaseVectorConfig.model_validate( + { + "host": "localhost", + "port": 5432, + "user": "dify", + "password": "secret", + "database": "dify", + "min_connection": 6, + "max_connection": 5, + } + ) + + +def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + + vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module)) + assert vector.get_type() == "vastbase" + assert vector.table_name == "embedding_collection_1" + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_and_add_texts(vastbase_module, monkeypatch): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + vector._create_collection = MagicMock() + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid") + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert ids == ["doc-a", "generated-uuid"] + vastbase_module.psycopg2.extras.execute_values.assert_called_once() + + vector.add_texts = MagicMock(return_value=["doc-a"]) + result = vector.create(docs, [[0.1], [0.2], [0.3]]) + vector._create_collection.assert_called_once_with(1) + assert result == ["doc-a"] + + +def test_text_get_delete_and_metadata_methods(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_ids([]) + vector.delete_by_ids(["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql) + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_and_full_text(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.8), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock()) + + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + cursor.execute.assert_not_called() + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(17000) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql) + + cursor.execute.reset_mock() + vector._create_collection(3) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql) + vastbase_module.redis_client.set.assert_called() + + +def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch): + factory = vastbase_module.VastbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5) + + with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py new file mode 100644 index 0000000000..0408506563 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py @@ -0,0 +1,328 @@ +import importlib +import os +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_qdrant_modules(): + qdrant_client = types.ModuleType("qdrant_client") + qdrant_http = types.ModuleType("qdrant_client.http") + qdrant_http_models = types.ModuleType("qdrant_client.http.models") + qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions") + qdrant_local_pkg = types.ModuleType("qdrant_client.local") + qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local") + + class UnexpectedResponseError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + class FilterSelector: + def __init__(self, filter): + self.filter = filter + + class HnswConfigDiff: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class TextIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class PointStruct: + def __init__(self, **kwargs): + self.id = kwargs["id"] + self.vector = kwargs["vector"] + self.payload = kwargs["payload"] + + class Filter: + def __init__(self, must=None): + self.must = must or [] + + class FieldCondition: + def __init__(self, key, match): + self.key = key + self.match = match + + class MatchValue: + def __init__(self, value): + self.value = value + + class MatchAny: + def __init__(self, any): + self.any = any + + class MatchText: + def __init__(self, text): + self.text = text + + class _Distance(UserDict): + def __getitem__(self, key): + return key + + class QdrantClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[])) + self.create_collection = MagicMock() + self.create_payload_index = MagicMock() + self.upsert = MagicMock() + self.delete = MagicMock() + self.delete_collection = MagicMock() + self.retrieve = MagicMock(return_value=[]) + self.search = MagicMock(return_value=[]) + self.scroll = MagicMock(return_value=([], None)) + + class QdrantLocal(QdrantClient): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._load = MagicMock() + + qdrant_client.QdrantClient = QdrantClient + qdrant_http_models.FilterSelector = FilterSelector + qdrant_http_models.HnswConfigDiff = HnswConfigDiff + qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD") + qdrant_http_models.TextIndexParams = TextIndexParams + qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT") + qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL") + qdrant_http_models.VectorParams = VectorParams + qdrant_http_models.Distance = _Distance() + qdrant_http_models.PointStruct = PointStruct + qdrant_http_models.Filter = Filter + qdrant_http_models.FieldCondition = FieldCondition + qdrant_http_models.MatchValue = MatchValue + qdrant_http_models.MatchAny = MatchAny + qdrant_http_models.MatchText = MatchText + qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError + + qdrant_http.models = qdrant_http_models + qdrant_local_mod.QdrantLocal = QdrantLocal + qdrant_local_pkg.qdrant_local = qdrant_local_mod + + return { + "qdrant_client": qdrant_client, + "qdrant_client.http": qdrant_http, + "qdrant_client.http.models": qdrant_http_models, + "qdrant_client.http.exceptions": qdrant_http_exceptions, + "qdrant_client.local": qdrant_local_pkg, + "qdrant_client.local.qdrant_local": qdrant_local_mod, + } + + +@pytest.fixture +def qdrant_module(monkeypatch): + for name, module in _build_fake_qdrant_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.qdrant.qdrant_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "endpoint": "http://localhost:6333", + "api_key": "api-key", + "timeout": 20, + "root_path": "/tmp", + "grpc_port": 6334, + "prefer_grpc": False, + "replication_factor": 1, + "write_consistency_factor": 1, + } + values.update(overrides) + return module.QdrantConfig.model_validate(values) + + +def test_qdrant_config_to_params(qdrant_module): + url_params = _config(qdrant_module).to_qdrant_params().model_dump() + assert url_params["url"] == "http://localhost:6333" + assert url_params["verify"] is False + + path_config = _config(qdrant_module, endpoint="path:storage") + assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage") + + with pytest.raises(ValueError, match="Root path is not set"): + _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params() + + +def test_init_and_basic_behaviour(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.get_type() == qdrant_module.VectorType.QDRANT + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1" + + docs = [Document(page_content="a", metadata={"doc_id": "a"})] + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with("collection_1", 1) + vector.add_texts.assert_called_once() + + +def test_create_collection_and_add_texts(qdrant_module, monkeypatch): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.get_collections.return_value = SimpleNamespace(collections=[]) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_called_once() + assert vector._client.create_payload_index.call_count == 4 + qdrant_module.redis_client.set.assert_called_once() + + # add_texts and generated batches + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.upsert.call_count == 1 + + payloads = qdrant_module.QdrantVector._build_payloads( + ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + assert payloads[0]["group_id"] == "g1" + with pytest.raises(ValueError, match="At least one of the texts is None"): + qdrant_module.QdrantVector._build_payloads( + [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + + +def test_delete_and_exists_paths(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete() + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete() + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = None + + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")]) + assert vector.text_exists("id-1") is False + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]) + vector._client.retrieve.return_value = [{"id": "id-1"}] + assert vector.text_exists("id-1") is True + + +def test_search_and_helper_methods(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.search_by_vector([0.1], score_threshold=1.0) == [] + + vector._client.search.return_value = [ + SimpleNamespace(payload=None, score=0.9, vector=[0.1]), + SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]), + ] + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + # full text search: keyword split, dedup and top_k limit + scroll_results = [ + ( + [ + SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]), + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ( + [ + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ] + vector._client.scroll.side_effect = scroll_results + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert vector.search_by_full_text(" ", top_k=2) == [] + + local_client = qdrant_module.QdrantLocal() + vector._client = local_client + vector._reload_if_needed() + local_client._load.assert_called_once() + + doc = vector._document_from_scored_point( + SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]), + "page_content", + "metadata", + ) + assert doc.page_content == "doc" + + +def test_qdrant_factory_paths(qdrant_module, monkeypatch): + factory = qdrant_module.QdrantVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + collection_binding_id=None, + index_struct_dict=None, + index_struct=None, + ) + monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root"))) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1) + + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset.index_struct is not None + + # collection binding lookup path + dataset.collection_binding_id = "binding-1" + dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}} + monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + qdrant_module.db.session.scalars = MagicMock( + return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION")) + ) + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION" + + qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None)) + with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py new file mode 100644 index 0000000000..ca8cd5e514 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -0,0 +1,303 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_relyt_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSession: + def __init__(self, execute_result=None): + self.execute_result = execute_result or MagicMock(fetchall=lambda: []) + self.execute = MagicMock(return_value=self.execute_result) + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +@pytest.fixture +def relyt_module(monkeypatch): + for name, module in _build_fake_relyt_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.relyt.relyt_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "relyt", + } + values.update(overrides) + return module.RelytConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config RELYT_HOST is required"), + ("port", 0, "config RELYT_PORT is required"), + ("user", "", "config RELYT_USER is required"), + ("password", "", "config RELYT_PASSWORD is required"), + ("database", "", "config RELYT_DATABASE is required"), + ], +) +def test_relyt_config_validation(relyt_module, field, value, message): + values = _config(relyt_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + relyt_module.RelytConfig.model_validate(values) + + +def test_init_get_type_and_create_delegate(relyt_module, monkeypatch): + engine = MagicMock() + monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine)) + vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1") + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == relyt_module.VectorType.RELYT + assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt" + assert vector.embedding_dimension == 2 + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock()) + + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + session.execute.assert_not_called() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] + assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) + assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql) + assert any("CREATE INDEX" in sql for sql in executed_sql) + relyt_module.redis_client.set.assert_called_once() + + +def test_add_texts_and_metadata_queries(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector._group_id = "group-1" + vector.client = MagicMock() + + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + + monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "d-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert conn.execute.call_count >= 1 + first_insert_values = conn.execute.call_args.args[0].compile().params + assert "group_id" in str(first_insert_values) + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"] + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +# 1. delete_by_uuids: success and connect error +def test_delete_by_uuids_success_and_connect_error(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + with pytest.raises(ValueError, match="No ids provided"): + vector.delete_by_uuids(None) + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + assert vector.delete_by_uuids(["id-1"]) is True + vector.client.connect.side_effect = RuntimeError("boom") + assert vector.delete_by_uuids(["id-1"]) is False + + +# 2. delete_by_metadata_field calls delete_by_uuids +def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_uuids.assert_called_once_with(["id-1"]) + + +# 3. delete_by_ids translates to uuids +def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_ids(["doc-1", "doc-2"]) + vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"]) + + +# 4. text_exists True +def test_text_exists_true(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is True + + +# 5. text_exists False +def test_text_exists_false(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is False + + +# 6. similarity_search_with_score_by_vector returns Documents and scores +def test_similarity_search_with_score_by_vector(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + result_rows = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8), + ] + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = result_rows + vector.client.connect.return_value = conn + similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]}) + assert len(similarities) == 2 + assert similarities[0][0].page_content == "doc-a" + + +# 7. search_by_vector filters by score and ids +def test_search_by_vector_filters_by_score_and_ids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.similarity_search_with_score_by_vector = MagicMock( + return_value=[ + (Document(page_content="a", metadata={"doc_id": "1"}), 0.1), + (Document(page_content="b", metadata={}), 0.9), + ] + ) + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert vector.search_by_full_text("query") == [] + + +# 8. delete commits session +def test_delete_commits_session(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete() + session.commit.assert_called_once() + + +def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): + factory = relyt_module.RelytVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432) + monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt") + + with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py new file mode 100644 index 0000000000..e3b6676d9b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py @@ -0,0 +1,316 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_tablestore_module(): + tablestore = types.ModuleType("tablestore") + + class _BatchGetRowRequest: + def __init__(self): + self.items = [] + + def add(self, item): + self.items.append(item) + + class _TableInBatchGetRowItem: + def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver): + self.table_name = table_name + self.rows_to_get = rows_to_get + self.columns_to_get = columns_to_get + + class _Row: + def __init__(self, primary_key, attribute_columns=None): + self.primary_key = primary_key + self.attribute_columns = attribute_columns or [] + + class _Client: + def __init__(self, *_args): + self.list_table = MagicMock(return_value=[]) + self.create_table = MagicMock() + self.list_search_index = MagicMock(return_value=[]) + self.create_search_index = MagicMock() + self.delete_search_index = MagicMock() + self.delete_table = MagicMock() + self.put_row = MagicMock() + self.delete_row = MagicMock() + self.get_row = MagicMock(return_value=(None, None, None)) + self.batch_get_row = MagicMock() + self.search = MagicMock() + + tablestore.OTSClient = _Client + tablestore.BatchGetRowRequest = _BatchGetRowRequest + tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem + tablestore.Row = _Row + tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema) + tablestore.TableOptions = lambda: ("table_options",) + tablestore.CapacityUnit = lambda read, write: ("capacity", read, write) + tablestore.ReservedThroughput = lambda cap: ("reserved", cap) + tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs) + tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs) + tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas) + tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs) + tablestore.TermQuery = lambda key, value: ("term_query", key, value) + tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs) + tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.TermsQuery = lambda key, values: ("terms_query", key, values) + tablestore.Sort = lambda **kwargs: ("sort", kwargs) + tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs) + tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs) + + tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD") + tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD") + tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32") + tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE") + tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX") + tablestore.SortOrder = SimpleNamespace(DESC="DESC") + return tablestore + + +@pytest.fixture +def tablestore_module(monkeypatch): + fake_module = _build_fake_tablestore_module() + monkeypatch.setitem(sys.modules, "tablestore", fake_module) + + import core.rag.datasource.vdb.tablestore.tablestore_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "access_key_id": "ak", + "access_key_secret": "sk", + "instance_name": "instance", + "endpoint": "endpoint", + "normalize_full_text_bm25_score": False, + } + values.update(overrides) + return module.TableStoreConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("access_key_id", "", "config ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"), + ("instance_name", "", "config INSTANCE_NAME is required"), + ("endpoint", "", "config ENDPOINT is required"), + ], +) +def test_tablestore_config_validation(tablestore_module, field, value, message): + values = _config(tablestore_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + tablestore_module.TableStoreConfig.model_validate(values) + + +def test_init_and_basic_delegation(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + assert vector.get_type() == tablestore_module.VectorType.TABLESTORE + assert vector._table_name == "collection_1" + assert vector._index_name == "collection_1_idx" + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]]) + + vector.create_collection([[0.1, 0.2]]) + assert vector._create_collection.call_count == 2 + + +def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + # get_by_ids + ok_item = SimpleNamespace( + is_ok=True, + row=SimpleNamespace( + attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)] + ), + ) + fail_item = SimpleNamespace(is_ok=False, row=None) + batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item]) + vector._tablestore_client.batch_get_row.return_value = batch_resp + docs = vector.get_by_ids(["id-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + + # text_exists + vector._tablestore_client.get_row.return_value = (None, object(), None) + assert vector.text_exists("id-1") is True + vector._tablestore_client.get_row.return_value = (None, None, None) + assert vector.text_exists("id-1") is False + + # delete wrappers + vector._delete_row = MagicMock() + vector.delete_by_ids([]) + vector._delete_row.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._delete_row.call_count == 2 + + vector._search_by_metadata = MagicMock(return_value=["id-a"]) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"] + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-a"]) + + vector._search_by_vector = MagicMock(return_value=["vec-doc"]) + vector._search_by_full_text = MagicMock(return_value=["fts-doc"]) + assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"] + assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"] + + vector._delete_table_if_exist = MagicMock() + vector.delete() + vector._delete_table_if_exist.assert_called_once() + + +def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table_if_not_exist = MagicMock() + vector._create_search_index_if_not_exist = MagicMock() + vector._create_collection(3) + vector._create_table_if_not_exist.assert_not_called() + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(3) + vector._create_table_if_not_exist.assert_called_once() + vector._create_search_index_if_not_exist.assert_called_once_with(3) + tablestore_module.redis_client.set.assert_called_once() + + vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module)) + vector._tablestore_client.list_table.return_value = ["collection_2"] + assert vector._create_table_if_not_exist() is None + vector._tablestore_client.list_table.return_value = [] + vector._create_table_if_not_exist() + vector._tablestore_client.create_table.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")] + assert vector._create_search_index_if_not_exist(3) is None + vector._tablestore_client.list_search_index.return_value = [] + vector._create_search_index_if_not_exist(3) + vector._tablestore_client.create_search_index.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")] + vector._delete_table_if_exist() + assert vector._tablestore_client.delete_search_index.call_count == 2 + vector._tablestore_client.delete_table.assert_called_once_with("collection_2") + + vector._delete_search_index() + vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx") + + +def test_write_row_and_search_helpers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + vector._write_row( + "id-1", + { + "page_content": "hello", + "vector": [0.1, 0.2], + "metadata": {"doc_id": "d-1", "document_id": "doc-1"}, + }, + ) + put_row_call = vector._tablestore_client.put_row.call_args + assert put_row_call.args[0] == "collection_1" + attrs = put_row_call.args[1].attribute_columns + assert any(item[0] == "metadata_tags" for item in attrs) + + vector._delete_row("id-1") + vector._tablestore_client.delete_row.assert_called_once() + + # metadata search pagination + first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next") + second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"") + vector._tablestore_client.search.side_effect = [first_page, second_page] + ids = vector._search_by_metadata("document_id", "doc-1") + assert ids == ["row-1", "row-2"] + vector._tablestore_client.search.side_effect = None + + # vector search + hit1 = SimpleNamespace( + score=0.9, + row=( + None, + [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))], + ), + ) + hit2 = SimpleNamespace( + score=0.2, + row=( + None, + [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))], + ), + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2]) + docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0) + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0 + + # full text search with and without normalized score filter + vector._normalize_full_text_bm25_score = True + hit3 = SimpleNamespace( + score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))]) + ) + hit4 = SimpleNamespace( + score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))]) + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4]) + docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2) + assert len(docs) == 1 + assert "score" in docs[0].metadata + + vector._normalize_full_text_bm25_score = False + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3]) + docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0) + assert len(docs) == 1 + assert "score" not in docs[0].metadata + + +def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch): + factory = tablestore_module.TableStoreVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True) + + with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py new file mode 100644 index 0000000000..d8f35a6019 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py @@ -0,0 +1,309 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_tencent_modules(): + tcvdb_text = types.ModuleType("tcvdb_text") + tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder") + tcvectordb = types.ModuleType("tcvectordb") + tcvectordb_model = types.ModuleType("tcvectordb.model") + tcvectordb_document = types.ModuleType("tcvectordb.model.document") + tcvectordb_index = types.ModuleType("tcvectordb.model.index") + tcvectordb_enum = types.ModuleType("tcvectordb.model.enum") + + class _BM25Encoder: + def encode_texts(self, text): + return {"encoded_text": text} + + def encode_queries(self, query): + return {"encoded_query": query} + + @classmethod + def default(cls, _lang): + return cls() + + class VectorDBError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class RPCVectorDBClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_database_if_not_exists = MagicMock() + self.exists_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[])) + self.create_collection = MagicMock() + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.search = MagicMock(return_value=[]) + self.hybrid_search = MagicMock(return_value=[]) + self.drop_collection = MagicMock() + + class _Document: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class _HNSWSearchParams: + def __init__(self, ef): + self.ef = ef + + class _AnnSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _KeywordSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _WeightedRerank: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Filter: + @staticmethod + def in_(field, values): + return ("in", field, values) + + def __init__(self, condition): + self.condition = condition + + _Filter.In = staticmethod(_Filter.in_) + + class _HNSWParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _FilterIndex: + def __init__(self, *args): + self.args = args + + class _VectorIndex: + def __init__(self, *args): + self.args = args + + class _SparseIndex: + def __init__(self, **kwargs): + self.kwargs = kwargs + + tcvectordb_enum.IndexType = SimpleNamespace( + __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"}, + PRIMARY_KEY="PRIMARY_KEY", + FILTER="FILTER", + SPARSE_INVERTED="SPARSE", + ) + tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP") + tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector") + + tcvectordb_document.Document = _Document + tcvectordb_document.HNSWSearchParams = _HNSWSearchParams + tcvectordb_document.AnnSearch = _AnnSearch + tcvectordb_document.Filter = _Filter + tcvectordb_document.KeywordSearch = _KeywordSearch + tcvectordb_document.WeightedRerank = _WeightedRerank + + tcvectordb_index.HNSWParams = _HNSWParams + tcvectordb_index.FilterIndex = _FilterIndex + tcvectordb_index.VectorIndex = _VectorIndex + tcvectordb_index.SparseIndex = _SparseIndex + + tcvdb_text_encoder.BM25Encoder = _BM25Encoder + + tcvectordb_model.document = tcvectordb_document + tcvectordb_model.enum = tcvectordb_enum + tcvectordb_model.index = tcvectordb_index + + tcvectordb.RPCVectorDBClient = RPCVectorDBClient + tcvectordb.VectorDBException = VectorDBError + + return { + "tcvdb_text": tcvdb_text, + "tcvdb_text.encoder": tcvdb_text_encoder, + "tcvectordb": tcvectordb, + "tcvectordb.model": tcvectordb_model, + "tcvectordb.model.document": tcvectordb_document, + "tcvectordb.model.index": tcvectordb_index, + "tcvectordb.model.enum": tcvectordb_enum, + } + + +@pytest.fixture +def tencent_module(monkeypatch): + for name, module in _build_fake_tencent_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.tencent.tencent_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "url": "http://vdb.local", + "api_key": "api-key", + "timeout": 30, + "username": "user", + "database": "db", + "index_type": "HNSW", + "metric_type": "IP", + "shard": 1, + "replicas": 2, + "max_upsert_batch_size": 2, + "enable_hybrid_search": False, + } + values.update(overrides) + return module.TencentConfig.model_validate(values) + + +def test_config_and_init_paths(tencent_module): + config = _config(tencent_module) + assert config.to_tencent_params()["url"] == "http://vdb.local" + + vector = tencent_module.TencentVector("collection_1", config) + assert vector.get_type() == tencent_module.VectorType.TENCENT + assert vector._client.kwargs["key"] == "api-key" + + vector._client.exists_collection.return_value = True + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)] + ) + vector._client_config.enable_hybrid_search = True + vector._load_collection() + assert vector._enable_hybrid_search is True + assert vector._dimension == 768 + + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=512)] + ) + vector._load_collection() + assert vector._enable_hybrid_search is False + + +def test_create_collection_branches(tencent_module, monkeypatch): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.exists_collection.return_value = True + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + vector._client.exists_collection.return_value = False + vector._client_config.index_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_collection(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_collection(3) + + vector._client_config.metric_type = "IP" + vector._client.create_collection.side_effect = [ + tencent_module.VectorDBException("fieldType:json unsupported"), + None, + ] + vector._enable_hybrid_search = True + vector._create_collection(3) + assert vector._client.create_collection.call_count == 2 + tencent_module.redis_client.set.assert_called_once() + vector._client.create_collection.side_effect = None + + +def test_create_add_delete_and_search_behaviour(tencent_module): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True)) + vector._create_collection = MagicMock() + docs = [ + Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}), + Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}), + Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + vector.create(docs, embeddings) + vector._create_collection.assert_called_once_with(1) + + vector._client.upsert.reset_mock() + vector.add_texts(docs, embeddings) + assert vector._client.upsert.call_count == 2 + first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"] + assert "sparse_vector" in first_docs[0].__dict__ + + vector._client.query.return_value = [{"id": "a"}] + assert vector.text_exists("a") is True + vector._client.query.return_value = [] + assert vector.text_exists("a") is False + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["a", "b", "c"]) + assert vector._client.delete.call_count == 2 + vector.delete_by_metadata_field("document_id", "doc-a") + assert vector._client.delete.call_count >= 3 + + vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]] + vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(vec_docs) == 1 + assert vec_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._enable_hybrid_search = False + assert vector.search_by_full_text("query") == [] + vector._enable_hybrid_search = True + vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]] + fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(fts_docs) == 1 + + # _get_search_res handles old string metadata format + compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5) + assert len(compat_docs) == 1 + assert compat_docs[0].metadata["score"] == pytest.approx(0.8) + + vector._has_collection = MagicMock(return_value=True) + vector.delete() + vector._client.drop_collection.assert_called_once() + + +def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch): + factory = tencent_module.TencentVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True) + + with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py new file mode 100644 index 0000000000..369cda39bf --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + + +class _DummyVector(BaseVector): + def __init__(self, collection_name: str, existing_ids: set[str] | None = None): + super().__init__(collection_name) + self._existing_ids = existing_ids or set() + + def get_type(self) -> str: + return "dummy" + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete_by_metadata_field(self, key: str, value: str): + return None + + def search_by_vector(self, query_vector: list[float], **kwargs): + return [] + + def search_by_full_text(self, query: str, **kwargs): + return [] + + def delete(self): + return None + + +@pytest.mark.parametrize( + ("base_method", "args"), + [ + (BaseVector.get_type, ()), + (BaseVector.create, ([], [])), + (BaseVector.add_texts, ([], [])), + (BaseVector.text_exists, ("doc-1",)), + (BaseVector.delete_by_ids, ([],)), + (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.search_by_vector, ([0.1],)), + (BaseVector.search_by_full_text, ("query",)), + (BaseVector.delete, ()), + ], +) +def test_base_vector_default_methods_raise_not_implemented(base_method, args): + vector = _DummyVector("collection_1") + + with pytest.raises(NotImplementedError): + base_method(vector, *args) + + +def test_filter_duplicate_texts_removes_existing_docs(): + vector = _DummyVector("collection_1", existing_ids={"dup"}) + docs = [ + SimpleNamespace(page_content="keep-no-meta", metadata=None), + Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}), + Document(page_content="remove-dup", metadata={"doc_id": "dup"}), + Document(page_content="keep-unique", metadata={"doc_id": "unique"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + + assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"] + + +def test_get_uuids_and_collection_name_property(): + vector = _DummyVector("collection_1") + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + Document(page_content="c", metadata={"document_id": "d-1"}), + Document(page_content="d", metadata={"doc_id": "id-2"}), + ] + + assert vector._get_uuids(docs) == ["id-1", "id-2"] + assert vector.collection_name == "collection_1" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py new file mode 100644 index 0000000000..dd536af759 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -0,0 +1,434 @@ +import base64 +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str): + fake_module = types.ModuleType(module_path) + fake_cls = type(class_name, (), {}) + setattr(fake_module, class_name, fake_cls) + monkeypatch.setitem(sys.modules, module_path, fake_module) + return fake_cls + + +@pytest.fixture +def vector_factory_module(): + import importlib + + import core.rag.datasource.vdb.vector_factory as module + + return importlib.reload(module) + + +def test_gen_index_struct_dict(vector_factory_module): + result = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict( + vector_factory_module.VectorType.WEAVIATE, + "collection_1", + ) + + assert result == { + "type": vector_factory_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "collection_1"}, + } + + +@pytest.mark.parametrize( + ("vector_type", "module_path", "class_name"), + [ + ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ( + "ALIBABACLOUD_MYSQL", + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "AlibabaCloudMySQLVectorFactory", + ), + ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ( + "ELASTICSEARCH", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "ElasticSearchVectorFactory", + ), + ( + "ELASTICSEARCH_JA", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "ElasticSearchJaVectorFactory", + ), + ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ( + "OPENSEARCH", + "core.rag.datasource.vdb.opensearch.opensearch_vector", + "OpenSearchVectorFactory", + ), + ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ( + "TIDB_ON_QDRANT", + "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "TidbOnQdrantVectorFactory", + ), + ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ( + "HUAWEI_CLOUD", + "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "HuaweiCloudVectorFactory", + ), + ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ], +) +def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): + expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name) + + result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type)) + + assert result_cls is expected_cls + + +def test_get_vector_factory_unsupported(vector_factory_module): + with pytest.raises(ValueError, match="not supported"): + vector_factory_module.Vector.get_vector_factory("unknown") + + +def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): + dataset = SimpleNamespace(id="dataset-1") + + with ( + patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), + patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), + ): + default_vector = vector_factory_module.Vector(dataset) + custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) + + assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + assert custom_vector._attributes == ["doc_id"] + assert default_vector._embeddings == "embeddings" + assert default_vector._vector_processor == "processor" + + +def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): + calls = {"vector_type": None, "init_args": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + calls["init_args"] = (dataset, attributes, embeddings) + return "vector-processor" + + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1" + ) + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH + assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings") + + +def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch): + class _Expr: + def __eq__(self, _other): + return "expr" + + calls = {"vector_type": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + return "vector-processor" + + monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))), + ) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True) + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT + + +def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch): + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = [] + vector._embeddings = "embeddings" + + with pytest.raises(ValueError, match="Vector store must be specified"): + vector._init_vector() + + +def test_create_batches_texts_and_skips_empty_input(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)] + vector._embeddings.embed_documents.side_effect = [ + [[0.1] for _ in range(1000)], + [[0.2]], + ] + + vector.create(texts=docs, trace_id="trace-1") + + assert vector._embeddings.embed_documents.call_count == 2 + assert vector._vector_processor.create.call_count == 2 + assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1" + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=None) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): + class _Field: + def in_(self, value): + return value + + def __eq__(self, value): + return value + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]] + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace( + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")]) + ) + ), + ) + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc")) + + docs = [ + Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}), + Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}), + ] + + vector.create_multimodal(file_documents=docs, request_id="r-1") + + file_base64 = base64.b64encode(b"abc").decode() + vector._embeddings.embed_multimodal_documents.assert_called_once_with( + [{"content": file_base64, "content_type": "image", "file_id": "f-1"}] + ) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], + embeddings=[[0.1, 0.2]], + request_id="r-1", + ) + + vector._embeddings.embed_multimodal_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create_multimodal(file_documents=None) + vector._embeddings.embed_multimodal_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_with_optional_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock() + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + vector._filter_duplicate_texts.return_value = [docs[0]] + vector._embeddings.embed_documents.return_value = [[0.1]] + + vector.add_texts(docs, duplicate_check=True, flag=True) + + vector._filter_duplicate_texts.assert_called_once_with(docs) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True + ) + + vector._filter_duplicate_texts.reset_mock() + vector._vector_processor.create.reset_mock() + vector._embeddings.embed_documents.return_value = [[0.2], [0.3]] + + vector.add_texts(docs, duplicate_check=False) + + vector._filter_duplicate_texts.assert_not_called() + vector._vector_processor.create.assert_called_once() + + +def test_vector_delegation_methods(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_query.return_value = [0.1, 0.2] + vector._vector_processor = MagicMock() + vector._vector_processor.text_exists.return_value = True + vector._vector_processor.search_by_vector.return_value = ["vector-doc"] + vector._vector_processor.search_by_full_text.return_value = ["text-doc"] + + assert vector.text_exists("doc-1") is True + vector.delete_by_ids(["doc-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"] + assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"] + + vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"]) + vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1") + + +def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): + class _Field: + def __eq__(self, value): + return value + + upload_query = MagicMock() + upload_query.where.return_value = upload_query + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr( + vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query)) + ) + + upload_query.first.return_value = None + assert vector.search_by_file("file-1") == [] + + upload_query.first.return_value = SimpleNamespace(key="blob-key") + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) + vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] + vector._vector_processor.search_by_vector.return_value = ["hit"] + + result = vector.search_by_file("file-2", top_k=2) + + assert result == ["hit"] + payload = vector._embeddings.embed_multimodal_query.call_args.args[0] + assert payload["content_type"] == vector_factory_module.DocType.IMAGE + assert payload["file_id"] == "file-2" + + +def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch): + delete_mock = MagicMock() + redis_delete = MagicMock() + monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1") + + vector.delete() + + delete_mock.assert_called_once() + redis_delete.assert_called_once_with("vector_indexing_collection_1") + + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="") + redis_delete.reset_mock() + vector.delete() + redis_delete.assert_not_called() + + +def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch): + model_manager = MagicMock() + model_manager.get_model_instance.return_value = "model-instance" + + monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + result = vector._get_embeddings() + + assert result == "cached-embedding" + model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=vector_factory_module.ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + +def test_filter_duplicate_texts_and_getattr(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup") + + docs = [ + SimpleNamespace(page_content="no-meta", metadata=None), + Document(page_content="empty-doc-id", metadata={"doc_id": ""}), + Document(page_content="duplicate", metadata={"doc_id": "dup"}), + Document(page_content="unique", metadata={"doc_id": "ok"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"] + + class _Processor: + def ping(self): + return "pong" + + vector._vector_processor = _Processor() + assert vector.ping() == "pong" + + with pytest.raises(AttributeError): + _ = vector.unknown_method + + vector._vector_processor = None + with pytest.raises(AttributeError, match="vector_processor"): + _ = vector.another_missing diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py new file mode 100644 index 0000000000..c25af79ae4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py @@ -0,0 +1,160 @@ +from unittest.mock import patch + +import httpx +import pytest +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse + +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import ( + TidbOnQdrantConfig, + TidbOnQdrantVector, +) + + +class TestTidbOnQdrantVectorDeleteByIds: + """Unit tests for TidbOnQdrantVector.delete_by_ids method.""" + + @pytest.fixture + def vector_instance(self): + """Create a TidbOnQdrantVector instance for testing.""" + config = TidbOnQdrantConfig( + endpoint="http://localhost:6333", + api_key="test_api_key", + ) + + with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): + vector = TidbOnQdrantVector( + collection_name="test_collection", + group_id="test_group", + config=config, + ) + return vector + + def test_delete_by_ids_with_multiple_ids(self, vector_instance): + """Test batch deletion with multiple document IDs.""" + ids = ["doc1", "doc2", "doc3"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once with MatchAny filter + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Check collection name + assert call_args[1]["collection_name"] == "test_collection" + + # Verify filter uses MatchAny with all IDs + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + assert len(filter_obj.must) == 1 + + field_condition = filter_obj.must[0] + assert field_condition.key == "metadata.doc_id" + assert isinstance(field_condition.match, rest.MatchAny) + assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"} + + def test_delete_by_ids_with_single_id(self, vector_instance): + """Test deletion with a single document ID.""" + ids = ["doc1"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Verify filter uses MatchAny with single ID + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ["doc1"] + + def test_delete_by_ids_with_empty_list(self, vector_instance): + """Test deletion with empty ID list returns early without API call.""" + vector_instance.delete_by_ids([]) + + # Verify that delete was NOT called + vector_instance._client.delete.assert_not_called() + + def test_delete_by_ids_with_404_error(self, vector_instance): + """Test that 404 errors (collection not found) are handled gracefully.""" + ids = ["doc1", "doc2"] + + # Mock a 404 error + error = UnexpectedResponse( + status_code=404, + reason_phrase="Not Found", + content=b"Collection not found", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should not raise an exception + vector_instance.delete_by_ids(ids) + + # Verify delete was called + vector_instance._client.delete.assert_called_once() + + def test_delete_by_ids_with_unexpected_error(self, vector_instance): + """Test that non-404 errors are re-raised.""" + ids = ["doc1", "doc2"] + + # Mock a 500 error + error = UnexpectedResponse( + status_code=500, + reason_phrase="Internal Server Error", + content=b"Server error", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should re-raise the exception + with pytest.raises(UnexpectedResponse) as exc_info: + vector_instance.delete_by_ids(ids) + + assert exc_info.value.status_code == 500 + + def test_delete_by_ids_with_large_batch(self, vector_instance): + """Test deletion with a large batch of IDs.""" + # Create 1000 IDs + ids = [f"doc_{i}" for i in range(1000)] + + vector_instance.delete_by_ids(ids) + + # Verify single delete call with all IDs + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + + # Verify all 1000 IDs are in the batch + assert len(field_condition.match.any) == 1000 + assert "doc_0" in field_condition.match.any + assert "doc_999" in field_condition.match.any + + def test_delete_by_ids_filter_structure(self, vector_instance): + """Test that the filter structure is correctly constructed.""" + ids = ["doc1", "doc2"] + + vector_instance.delete_by_ids(ids) + + call_args = vector_instance._client.delete.call_args + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + + # Verify Filter structure + assert isinstance(filter_obj, rest.Filter) + assert filter_obj.must is not None + assert len(filter_obj.must) == 1 + + # Verify FieldCondition structure + field_condition = filter_obj.must[0] + assert isinstance(field_condition, rest.FieldCondition) + assert field_condition.key == "metadata.doc_id" + + # Verify MatchAny structure + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ids diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000..951a920f3b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,443 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +@pytest.fixture +def tidb_module(): + import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + + return importlib.reload(module) + + +def _config(tidb_module): + return tidb_module.TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="secret", + database="dify", + program_name="dify-app", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config TIDB_VECTOR_HOST is required"), + ("port", 0, "config TIDB_VECTOR_PORT is required"), + ("user", "", "config TIDB_VECTOR_USER is required"), + ("database", "", "config TIDB_VECTOR_DATABASE is required"), + ("program_name", "", "config APPLICATION_NAME is required"), + ], +) +def test_tidb_config_validation(tidb_module, field, value, message): + values = _config(tidb_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + tidb_module.TiDBVectorConfig.model_validate(values) + + +def test_init_get_type_and_distance_func(tidb_module, monkeypatch): + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine")) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2") + + assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR + assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify") + assert vector._dimension == 1536 + assert vector._get_distance_func() == "VEC_L2_DISTANCE" + + vector._distance_func = "cosine" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + vector._distance_func = "other" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + +def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch): + fake_tidb_vector = types.ModuleType("tidb_vector") + fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy") + + class _VectorType: + def __init__(self, dim): + self.dim = dim + + fake_tidb_sqlalchemy.VectorType = _VectorType + + monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector) + monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy) + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock())) + monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr( + tidb_module, + "Table", + lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns), + ) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module)) + table = vector._table(3) + + assert table.name == "collection_1" + column_names = [column.args[0] for column in table.columns] + assert tidb_module.Field.PRIMARY_KEY in column_names + assert tidb_module.Field.VECTOR in column_names + assert tidb_module.Field.TEXT_KEY in column_names + + +def test_create_calls_collection_and_add_texts(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + assert vector._dimension == 2 + + +def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + + tidb_module.Session = MagicMock() + + vector._create_collection(3) + + tidb_module.Session.assert_not_called() + tidb_module.redis_client.set.assert_not_called() + + +def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "l2" + + vector._create_collection(3) + + session.begin.assert_called_once() + sql = str(session.execute.call_args.args[0]) + assert "VECTOR(3)" in sql + assert "VEC_L2_DISTANCE" in sql + session.commit.assert_called_once() + tidb_module.redis_client.set.assert_called_once() + + +def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch): + class _InsertStmt: + def __init__(self, table): + self.table = table + + def values(self, rows): + return {"table": self.table, "rows": rows} + + monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table)) + + conn = MagicMock() + transaction = MagicMock() + transaction.__enter__.return_value = None + transaction.__exit__.return_value = None + conn.begin.return_value = transaction + + connection_ctx = MagicMock() + connection_ctx.__enter__.return_value = conn + connection_ctx.__exit__.return_value = None + + engine = MagicMock() + engine.connect.return_value = connection_ctx + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._engine = engine + vector._table = MagicMock(return_value="table") + + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)] + embeddings = [[float(i)] for i in range(501)] + + ids = vector.add_texts(docs, embeddings) + + assert ids[0] == "id-0" + assert len(ids) == 501 + assert conn.execute.call_count == 2 + + +@pytest.fixture +def tidb_vector_with_session(tidb_module, monkeypatch): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + return vector, session, tidb_module + + +# 1. search_by_full_text returns empty +def test_search_by_full_text_returns_empty(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + assert vector.search_by_full_text("query") == [] + + +# 2. text_exists returns True when ids found +def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + + +# 3. text_exists returns False when no ids +def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=None) + assert vector.text_exists("doc-1") is False + + +# 4. delete_by_ids delegates to _delete_by_ids when ids found +def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"]) + + +# 5. delete_by_ids skips when no ids found +def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-c"]) + vector._delete_by_ids.assert_not_called() + + +# 6. get_ids_by_metadata_field returns ids and returns None +def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + # Returns ids + session.execute.return_value.fetchall.return_value = [("id-1",)] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"] + # Returns None + session.execute.return_value.fetchall.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None + + +# 1. _delete_by_ids raises on None +def test__delete_by_ids_raises_on_none(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + with pytest.raises(ValueError, match="No ids provided"): + vector._delete_by_ids(None) + + +# 2. _delete_by_ids returns True and calls execute +def test__delete_by_ids_returns_true_and_calls_execute(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-1"]) is True + conn.execute.assert_called_once() + + +# 3. _delete_by_ids returns False on RuntimeError +def test__delete_by_ids_returns_false_on_runtime_error(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + conn.execute.side_effect = RuntimeError("delete failed") + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-2"]) is False + + +# 4. delete_by_metadata_field calls _delete_by_ids when ids found +def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-3"]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-3") + vector._delete_by_ids.assert_called_once_with(["id-3"]) + + +# 5. delete_by_metadata_field does nothing when no ids +def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-4") + vector._delete_by_ids.assert_not_called() + + +# Test search_by_vector filters and scores +def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4), + ] + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params["distance"] == pytest.approx(0.5) + assert params["top_k"] == 2 + session.commit.assert_not_called() + + +# Test delete drops table +def test_delete_drops_table(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = None + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector.delete() + drop_sql = str(session.execute.call_args.args[0]) + assert "DROP TABLE IF EXISTS collection_1" in drop_sql + session.commit.assert_called_once() + + +def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): + factory = tidb_module.TiDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000) + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") + + with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000..ac8a63a44b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,186 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_upstash_module(): + upstash_module = types.ModuleType("upstash_vector") + + class Vector: + def __init__(self, id, vector, metadata, data): + self.id = id + self.vector = vector + self.metadata = metadata + self.data = data + + class Index: + def __init__(self, url, token): + self.url = url + self.token = token + self.info = MagicMock(return_value=SimpleNamespace(dimension=8)) + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.reset = MagicMock() + + upstash_module.Vector = Vector + upstash_module.Index = Index + return upstash_module + + +@pytest.fixture +def upstash_module(monkeypatch): + # Remove patched modules if present + for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + if modname in sys.modules: + monkeypatch.delitem(sys.modules, modname, raising=False) + monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) + module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + return module + + +def _config(module): + return module.UpstashVectorConfig(url="https://upstash.example", token="token-123") + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("url", "", "Upstash URL is required"), + ("token", "", "Upstash Token is required"), + ], +) +def test_upstash_config_validation(upstash_module, field, value, message): + values = _config(upstash_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + upstash_module.UpstashVectorConfig.model_validate(values) + + +def test_init_get_type_and_dimension(upstash_module, monkeypatch): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + assert vector.get_type() == upstash_module.VectorType.UPSTASH + assert vector._table_name == "collection_1" + assert vector._get_index_dimension() == 8 + + vector.index.info.return_value = SimpleNamespace(dimension=None) + assert vector._get_index_dimension() == 1536 + + vector.index.info.return_value = None + assert vector._get_index_dimension() == 1536 + + monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid") + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.add_texts(docs, [[0.1, 0.2]]) + + vector.index.upsert.assert_called_once() + upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"] + assert upsert_vectors[0].id == "generated-uuid" + + +def test_create_text_exists_and_delete_by_ids(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + vector.get_ids_by_metadata_field.return_value = [] + assert vector.text_exists("doc-1") is False + + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]]) + vector._delete_by_ids = MagicMock() + vector.delete_by_ids(["doc-1", "doc-2", "doc-3"]) + vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"]) + + +def test_delete_helpers_and_search(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + vector._delete_by_ids([]) + vector.index.delete.assert_not_called() + vector._delete_by_ids(["a", "b"]) + vector.index.delete.assert_called_once_with(ids=["a", "b"]) + + vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")] + ids = vector.get_ids_by_metadata_field("doc_id", "doc-1") + assert ids == ["x-1", "x-2"] + query_kwargs = vector.index.query.call_args.kwargs + assert query_kwargs["top_k"] == 1000 + assert query_kwargs["filter"] == "doc_id = 'doc-1'" + + vector._delete_by_ids = MagicMock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + vector._delete_by_ids.assert_called_once_with(["x-1"]) + + vector._delete_by_ids.reset_mock() + vector.get_ids_by_metadata_field.return_value = [] + vector.delete_by_metadata_field("doc_id", "doc-2") + vector._delete_by_ids.assert_not_called() + + +def test_search_by_vector_filter_threshold_and_delete(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.index.query.return_value = [ + SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9), + SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3), + SimpleNamespace(metadata=None, data="text-3", score=0.99), + SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=3, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + search_kwargs = vector.index.query.call_args.kwargs + assert search_kwargs["top_k"] == 3 + assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')" + + assert vector.search_by_full_text("query") == [] + + vector.delete() + vector.index.reset.assert_called_once() + + +def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch): + factory = upstash_module.UpstashVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123") + + with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py new file mode 100644 index 0000000000..9da92af2d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py @@ -0,0 +1,310 @@ +import importlib +import json +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_vikingdb_modules(): + volcengine = types.ModuleType("volcengine") + volcengine.__path__ = [] + viking_db = types.ModuleType("volcengine.viking_db") + + class Data(UserDict): + def __init__(self, payload): + super().__init__(payload) + self.fields = payload + + class DistanceType: + L2 = "L2" + + class IndexType: + HNSW = "HNSW" + + class QuantType: + Float = "Float" + + class FieldType: + String = "string" + Text = "text" + Vector = "vector" + + class Field: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Collection: + def __init__(self): + self.upsert_data = MagicMock() + self.fetch_data = MagicMock(return_value=None) + self.delete_data = MagicMock() + + class _Index: + def __init__(self): + self.search = MagicMock(return_value=[]) + self.search_by_vector = MagicMock(return_value=[]) + + class VikingDBService: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_collection = MagicMock() + self.create_index = MagicMock() + self.drop_index = MagicMock() + self.drop_collection = MagicMock() + self._collection = _Collection() + self._index = _Index() + self.get_collection = MagicMock(return_value=self._collection) + self.get_index = MagicMock(return_value=self._index) + + viking_db.Data = Data + viking_db.DistanceType = DistanceType + viking_db.Field = Field + viking_db.FieldType = FieldType + viking_db.IndexType = IndexType + viking_db.QuantType = QuantType + viking_db.VectorIndexParams = VectorIndexParams + viking_db.VikingDBService = VikingDBService + + return {"volcengine": volcengine, "volcengine.viking_db": viking_db} + + +@pytest.fixture +def vikingdb_module(monkeypatch): + for name, module in _build_fake_vikingdb_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VikingDBConfig( + access_key="ak", + secret_key="sk", + host="host", + region="region", + scheme="https", + connection_timeout=10, + socket_timeout=20, + ) + + +def test_init_get_type_and_has_checks(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB + assert vector._index_name == "collection_1_idx" + + assert vector._has_collection() is True + assert vector._has_index() is True + + vector._client.get_collection.side_effect = RuntimeError("missing") + assert vector._has_collection() is False + vector._client.get_collection.side_effect = None + + vector._client.get_index.side_effect = RuntimeError("missing") + assert vector._has_index() is False + + +def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock()) + + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + vector._client.create_index.assert_not_called() + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None)) + vector._has_collection = MagicMock(return_value=False) + vector._has_index = MagicMock(return_value=False) + vector._create_collection(4) + + vector._client.create_collection.assert_called_once() + vector._client.create_index.assert_called_once() + vikingdb_module.redis_client.set.assert_called_once() + + +def test_create_and_add_texts(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module)) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}), + ] + vector.add_texts(docs, [[0.1], [0.2]]) + + vector._client.get_collection.assert_called() + upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0] + assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a" + assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2" + + +def test_text_exists_and_delete_operations(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"}) + assert vector.text_exists("id-1") is True + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace( + fields={"message": "data does not exist"} + ) + assert vector.text_exists("id-1") is False + + vector._client.get_collection.return_value.fetch_data.return_value = None + assert vector.text_exists("id-1") is False + + vector.delete_by_ids(["id-1"]) + vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-2"]) + + +def test_get_ids_and_search_helpers(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_index.return_value.search.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "x") == [] + + vector._client.get_index.return_value.search.return_value = [ + SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}), + SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}), + SimpleNamespace(id="c", fields={}), + ] + assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"] + + empty_docs = vector._get_search_res([], score_threshold=0.1) + assert empty_docs == [] + + results = [ + SimpleNamespace( + id="a", + score=0.3, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}), + }, + ), + SimpleNamespace( + id="b", + score=0.9, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}), + }, + ), + ] + + docs = vector._get_search_res(results, score_threshold=0.2) + assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"] + + vector._client.get_index.return_value.search_by_vector.return_value = results + filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"]) + assert len(filtered_docs) == 1 + assert filtered_docs[0].page_content == "doc-b" + assert vector.search_by_full_text("query") == [] + + +def test_delete_drops_index_and_collection_when_present(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._has_index = MagicMock(return_value=True) + vector._has_collection = MagicMock(return_value=True) + + vector.delete() + + vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx") + vector._client.drop_collection.assert_called_once_with("collection_1") + + vector._client.drop_index.reset_mock() + vector._client.drop_collection.reset_mock() + vector._has_index.return_value = False + vector._has_collection.return_value = False + vector.delete() + + vector._client.drop_index.assert_not_called() + vector._client.drop_collection.assert_not_called() + + +def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch): + factory = vikingdb_module.VikingDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls: + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10) + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20) + + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +@pytest.mark.parametrize( + ("field", "message"), + [ + ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"), + ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"), + ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"), + ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"), + ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"), + ], +) +def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message): + factory = vikingdb_module.VikingDBVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None + ) + + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, field, None) + + with pytest.raises(ValueError, match=message): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py index 3bd656ba84..69d1833001 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in: - Full-text search result metadata (search_by_full_text) """ +import datetime +import json import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase): def tearDown(self): weaviate_vector_module._weaviate_client = None + def test_config_requires_endpoint(self): + with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): + WeaviateConfig(endpoint="") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" @@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase): ) return wv, mock_client + def test_shutdown_client_logs_debug_when_close_fails(self): + mock_client = MagicMock() + mock_client.close.side_effect = RuntimeError("close failed") + weaviate_vector_module._weaviate_client = mock_client + + with patch.object(weaviate_vector_module.logger, "debug") as mock_debug: + weaviate_vector_module._shutdown_weaviate_client() + + assert weaviate_vector_module._weaviate_client is None + mock_client.close.assert_called_once() + mock_debug.assert_called_once() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.return_value = True + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.side_effect = [False, True] + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + config = WeaviateConfig( + endpoint="https://weaviate.example.com", + grpc_endpoint="grpc.example.com:6000", + api_key="test-key", + batch_size=50, + ) + + client = wv._init_client(config) + + assert client is mock_client + assert mock_connect.call_args.kwargs == { + "http_host": "weaviate.example.com", + "http_port": 443, + "http_secure": True, + "grpc_host": "grpc.example.com", + "grpc_port": 6000, + "grpc_secure": False, + "auth_credentials": "auth-token", + "skip_init_checks": True, + } + mock_api_key.assert_called_once_with("test-key") + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_raises_when_database_not_ready(self, mock_connect): + mock_client = MagicMock() + mock_client.is_ready.return_value = False + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + + with pytest.raises(ConnectionError, match="Vector database is not ready"): + wv._init_client(self.config) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" @@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase): assert wv._collection_name == self.collection_name assert "doc_type" in wv._attributes + def test_get_type_and_to_index_struct(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + + assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE + assert wv.to_index_struct() == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": self.collection_name}, + } + + def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self): + dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1") + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv.get_collection_name(dataset) == "ExistingCollection_Node" + + def test_get_collection_name_generates_name_from_dataset_id(self): + dataset = SimpleNamespace(index_struct_dict=None, id="ds-2") + wv = WeaviateVector.__new__(WeaviateVector) + + with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"): + assert wv.get_collection_name(dataset) == "Generated_Node" + + def test_create_calls_collection_setup_then_add_texts(self): + doc = Document(page_content="hello", metadata={}) + wv = WeaviateVector.__new__(WeaviateVector) + wv._create_collection = MagicMock() + wv.add_texts = MagicMock() + + wv.create([doc], [[0.1, 0.2]]) + + wv._create_collection.assert_called_once() + wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._ensure_properties = MagicMock() + + wv._create_collection() + + wv._client.collections.exists.assert_not_called() + wv._ensure_properties.assert_not_called() + mock_redis.set.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_logs_and_reraises_errors(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock(return_value=False) + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = RuntimeError("create failed") + + with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: + with pytest.raises(RuntimeError, match="create failed"): + wv._create_collection() + + mock_exception.assert_called_once() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" @@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [SimpleNamespace(name="text")] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" @@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + mock_col.config.add_property.side_effect = RuntimeError("cannot add") + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: + wv._ensure_properties() + + assert mock_warning.call_count == 4 + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = { + "text": "fallback distance result", + "document_id": "doc-1", + "doc_id": "segment-1", + } + mock_obj.metadata = None + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector( + query_vector=[0.2] * 3, + document_ids_filter=["doc-1"], + top_k=2, + score_threshold=-1, + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.0 + assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_vector(query_vector=[0.1] * 3) == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" @@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"} + mock_obj.vector = [0.3, 0.4] + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].vector == [0.3, 0.4] + assert mock_col.query.bm25.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_full_text(query="missing") == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" @@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + doc = Document(page_content="text", metadata={"created_at": created_at}) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with ( + patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), + patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + ): + ids = wv.add_texts(documents=[doc], embeddings=[[]]) + + assert ids == ["fallback-uuid"] + call_kwargs = mock_batch.add_object.call_args + assert call_kwargs.kwargs["uuid"] == "fallback-uuid" + assert call_kwargs.kwargs["vector"] is None + assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat() + + def test_is_uuid_handles_invalid_values(self): + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True + assert wv._is_uuid("not-a-uuid") is False + + def test_delete_by_metadata_field_returns_when_collection_is_missing(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = False + + wv.delete_by_metadata_field("doc_id", "segment-1") + + wv._client.collections.use.assert_not_called() + + def test_delete_by_metadata_field_deletes_matching_objects(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + + wv.delete_by_metadata_field("doc_id", "segment-1") + + mock_col.data.delete_many.assert_called_once() + + def test_delete_removes_collection_when_present(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + + wv.delete() + + wv._client.collections.delete.assert_called_once_with(self.collection_name) + + def test_text_exists_handles_missing_and_present_documents(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()]) + + assert wv.text_exists("segment-1") is False + assert wv.text_exists("segment-1") is True + + def test_delete_by_ids_handles_missing_collections_and_404s(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None] + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + wv.delete_by_ids(["ignored"]) + wv.delete_by_ids(["missing-id", "ok-id"]) + + assert mock_col.data.delete_by_id.call_count == 2 + + def test_delete_by_ids_reraises_non_404_errors(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): + wv.delete_by_ids(["bad-id"]) + + def test_json_serializable_converts_datetime(self): + wv = WeaviateVector.__new__(WeaviateVector) + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + + assert wv._json_serializable(created_at) == created_at.isoformat() + assert wv._json_serializable("plain") == "plain" + class TestVectorDefaultAttributes(unittest.TestCase): """Tests for Vector class default attributes list.""" @@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase): assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" +class TestWeaviateVectorFactory(unittest.TestCase): + def test_init_vector_uses_existing_dataset_index_struct(self): + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}}, + index_struct=None, + ) + attributes = ["doc_id"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + config = mock_vector.call_args.kwargs["config"] + assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node" + assert mock_vector.call_args.kwargs["attributes"] == attributes + assert config.endpoint == "http://localhost:8080" + assert config.grpc_endpoint == "localhost:50051" + assert config.api_key == "api-key" + assert config.batch_size == 88 + assert dataset.index_struct is None + + def test_init_vector_generates_collection_and_updates_index_struct(self): + dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + attributes = ["doc_id", "doc_type"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100), + patch.object( + weaviate_vector_module.Dataset, + "gen_collection_name_by_id", + return_value="GeneratedCollection_Node", + ), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node" + assert json.loads(dataset.index_struct) == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "GeneratedCollection_Node"}, + } + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index d3040395be..db49221583 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -104,10 +104,11 @@ class TestFirecrawlApp: def test_map_known_error(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") - mock_handle = mocker.patch.object(app, "_handle_error") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error")) mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"})) - assert app.map("https://example.com") == {} + with pytest.raises(Exception, match="map error"): + app.map("https://example.com") mock_handle.assert_called_once() def test_map_unknown_error_raises(self, mocker: MockerFixture): @@ -163,6 +164,13 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="No page found"): app.check_crawl_status("job-1") + def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") payload = {"status": "processing", "total": 5, "completed": 1, "data": []} @@ -177,10 +185,11 @@ class TestFirecrawlApp: def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") - mock_handle = mocker.patch.object(app, "_handle_error") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error")) mocker.patch("httpx.get", return_value=_response(500, {"error": "server"})) - assert app.check_crawl_status("job-1") == {} + with pytest.raises(Exception, match="crawl error"): + app.check_crawl_status("job-1") mock_handle.assert_called_once() def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture): @@ -201,6 +210,77 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="Error saving crawl data"): app.check_crawl_status("job-err") + def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture): + """When status is completed and next is present, follow pagination to collect all pages.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + page2 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2", + "data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}], + } + page3 = { + "status": "completed", + "total": 3, + "completed": 3, + "data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)]) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"] + + def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture): + """An error while fetching a paginated page raises an exception; no partial data is returned.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 2, + "completed": 2, + "next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})]) + + with pytest.raises(Exception, match="fetch next crawl page"): + app.check_crawl_status("job-99") + + def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture): + """Pagination stops once pages_processed reaches total, even if next is present.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + # total=1: only the first page should be processed; next must not be followed + page1 = { + "status": "completed", + "total": 1, + "completed": 1, + "next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mock_get = mocker.patch("httpx.get", return_value=_response(200, page1)) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-cap") + + assert len(result["data"]) == 1 + mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap + def test_extract_common_fields_and_status_formatter(self): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") @@ -272,9 +352,10 @@ class TestFirecrawlApp: def test_search_known_http_error(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") - mock_handle = mocker.patch.object(app, "_handle_error") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error")) mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"})) - assert app.search("python") == {} + with pytest.raises(Exception, match="search error"): + app.search("python") mock_handle.assert_called_once() def test_search_unknown_http_error(self, mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 2451db70b6..2c234edd9a 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -236,7 +237,8 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [accepted, rejected] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].metadata["score"] == 0.9 @@ -266,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index abe40f05d1..b1ed735ee7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -307,7 +308,8 @@ class TestParentChildIndexProcessor: "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [ok_result, low_result] - docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "keep" diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 8596647ef3..98c47bec8f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -262,7 +263,8 @@ class TestQAIndexProcessor: with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: mock_retrieve.return_value = [result_ok, result_low] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "accepted" @@ -297,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..b54a74b69c 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,7 +61,7 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -521,7 +521,7 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -605,7 +605,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -674,7 +674,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +701,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -795,7 +795,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +949,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index d61f01c616..a34ca330ca 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.rag.data_post_processor.data_post_processor import WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage: extra={"dataset_name": "Ext", "title": "Ext"}, ) app = Flask(__name__) - weights = {"vector_setting": {}} + weights: WeightsDict = { + "vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""}, + "keyword_setting": {"keyword_weight": 0.5}, + } def fake_multiple_thread(**kwargs): if kwargs["query"]: @@ -4796,8 +4800,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py new file mode 100644 index 0000000000..4116e8b4a5 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -0,0 +1,677 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import Sequence +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, + _HumanInputFormEntityImpl, + _HumanInputFormRecipientEntityImpl, + _InvalidTimeoutStatusError, + _WorkspaceMemberInfo, +) +from dify_graph.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, + UserAction, + WebAppDeliveryMethod, +) +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError +from libs.datetime_utils import naive_utc_now +from models.human_input import HumanInputFormRecipient, RecipientType + + +@pytest.fixture(autouse=True) +def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeSelect: + def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect()) + monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader") + + +def _make_form_definition_json(*, include_expiration_time: bool) -> str: + payload: dict[str, Any] = { + "form_content": "hi", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "rendered_content": "

hi

", + } + if include_expiration_time: + payload["expiration_time"] = naive_utc_now() + return json.dumps(payload, default=str) + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str | None + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str | None + + +class _FakeScalarResult: + def __init__(self, obj: Any): + self._obj = obj + + def first(self) -> Any: + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self) -> list[Any]: + if self._obj is None: + return [] + if isinstance(self._obj, list): + return list(self._obj) + return [self._obj] + + +class _FakeExecuteResult: + def __init__(self, rows: Sequence[tuple[Any, ...]]): + self._rows = list(rows) + + def all(self) -> list[tuple[Any, ...]]: + return list(self._rows) + + +class _FakeSession: + def __init__( + self, + *, + scalars_result: Any = None, + scalars_results: list[Any] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + execute_rows: Sequence[tuple[Any, ...]] = (), + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + else: + self._scalars_queue = [scalars_result] + self._forms = forms or {} + self._recipients = recipients or {} + self._execute_rows = list(execute_rows) + self.added: list[Any] = [] + + def scalars(self, _query: Any) -> _FakeScalarResult: + if self._scalars_queue: + value = self._scalars_queue.pop(0) + else: + value = None + return _FakeScalarResult(value) + + def execute(self, _stmt: Any) -> _FakeExecuteResult: + return _FakeExecuteResult(self._execute_rows) + + def get(self, model_cls: Any, obj_id: str) -> Any: + name = getattr(model_cls, "__name__", "") + if name == "HumanInputForm": + return self._forms.get(obj_id) + if name == "HumanInputFormRecipient": + return self._recipients.get(obj_id) + return None + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: Sequence[Any]) -> None: + self.added.extend(list(objs)) + + def flush(self) -> None: + # Simulate DB default population for attributes referenced in entity wrappers. + for obj in self.added: + if hasattr(obj, "id") and obj.id in (None, ""): + obj.id = f"gen-{len(str(self.added))}" + if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None: + if obj.recipient_type == RecipientType.CONSOLE: + obj.access_token = "token-console" + elif obj.recipient_type == RecipientType.BACKSTAGE: + obj.access_token = "token-backstage" + else: + obj.access_token = "token-webapp" + + def refresh(self, _obj: Any) -> None: + return None + + def begin(self) -> _FakeSession: + return self + + def __enter__(self) -> _FakeSession: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +class _SessionFactoryStub: + def __init__(self, session: _FakeSession): + self._session = session + + def create_session(self) -> _FakeSession: + return self._session + + +def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session)) + + +def test_recipient_entity_token_raises_when_missing() -> None: + recipient = SimpleNamespace(id="r1", access_token=None) + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + with pytest.raises(AssertionError, match="access_token should not be None"): + _ = entity.token + + +def test_recipient_entity_id_and_token_success() -> None: + recipient = SimpleNamespace(id="r1", access_token="tok") + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + assert entity.id == "r1" + assert entity.token == "tok" + + +def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok") + webapp = _DummyRecipient( + id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok" + ) + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] + assert entity.web_app_token == "ctok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] + assert entity.web_app_token == "wtok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.web_app_token is None + + +def test_form_entity_submitted_data_parsed() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + submitted_data='{"a": 1}', + submitted_at=naive_utc_now(), + ) + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.submitted is True + assert entity.submitted_data == {"a": 1} + assert entity.rendered_content == "

x

" + assert entity.selected_action_id is None + assert entity.status == HumanInputFormStatus.WAITING + + +def test_form_record_from_models_injects_expiration_time_when_missing() -> None: + expiration = naive_utc_now() + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=False), + rendered_content="

x

", + expiration_time=expiration, + submitted_data='{"k": "v"}', + ) + record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type] + assert record.definition.expiration_time == expiration + assert record.submitted_data == {"k": "v"} + assert record.submitted is False + + +def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + id=f"{payload.TYPE}-{len(created)}", + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token="tok", + ) + created.append(recipient) + return recipient + + monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new)) + + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined] + form_id="f", + delivery_id="d", + members=[ + _WorkspaceMemberInfo(user_id="u1", email=""), + _WorkspaceMemberInfo(user_id="u2", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u3", email="a@example.com"), + ], + external_emails=["", "a@example.com", "b@example.com", "b@example.com"], + ) + assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL] + + +def test_query_workspace_members_by_ids_empty_returns_empty() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == [] + + +def test_query_workspace_members_by_ids_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"]) + assert rows == [ + _WorkspaceMemberInfo(user_id="u1", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u2", email="b@example.com"), + ] + + +def test_query_all_workspace_members_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_all_workspace_members(session=session) + assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + +def test_repository_init_sets_tenant_id() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._tenant_id == "tenant" + + +def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + result = repo._delivery_method_to_model( + session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod() + ) + assert result.delivery.id == "del-1" + assert result.delivery.form_id == "form-1" + assert len(result.recipients) == 1 + assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP + + +def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + called: dict[str, Any] = {} + + def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]: + called.update( + {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config} + ) + return ["r"] + + monkeypatch.setattr(repo, "_build_email_recipients", fake_build) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + subject="s", + body="b", + ) + ) + result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method) + assert result.recipients == ["r"] + assert called["delivery_id"] == "del-1" + + +def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr( + repo, + "_query_all_workspace_members", + lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")], + ) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + ) + assert recipients == ["ok"] + + +def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]: + assert restrict_to_user_ids == ["u1"] + return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + ) + assert recipients == ["ok"] + + +def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo.get_form("run", "node") is None + + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + entity = repo.get_form("run", "node") + assert entity is not None + assert entity.id == "f1" + assert entity.recipients[0].id == "r1" + assert entity.recipients[0].token == "tok" + + +def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + ids = iter(["form-id", "del-web", "del-console", "del-backstage"]) + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids)) + + session = _FakeSession() + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + form_config = HumanInputNodeData( + title="Title", + delivery_methods=[], + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + params = FormCreateParams( + app_id="app", + workflow_execution_id="run", + node_id="node", + form_config=form_config, + rendered_content="

hello

", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + form_kind=HumanInputFormKind.RUNTIME, + console_recipient_required=True, + console_creator_account_id="acc-1", + backstage_recipient_required=True, + ) + + entity = repo.create_form(params) + assert entity.id == "form-id" + assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) + # Console token should take precedence when console recipient is present. + assert entity.web_app_token == "token-console" + assert len(entity.recipients) == 3 + + +def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + recipient = SimpleNamespace(form=None) + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + +def test_submission_repository_init_no_args() -> None: + repo = HumanInputFormSubmissionRepository() + assert isinstance(repo, HumanInputFormSubmissionRepository) + + +def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = SimpleNamespace( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + form=form, + ) + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_token("tok") + assert record is not None + assert record.access_token == "tok" + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP) + assert record is not None + assert record.recipient_id == "r1" + + +def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None + + +def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + missing_session = _FakeSession(forms={}) + _patch_session_factory(monkeypatch, missing_session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_submitted( + form_id="missing", + recipient_id=None, + selected_action_id="a", + form_data={}, + submission_user_id=None, + submission_end_user_id=None, + ) + + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok") + session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"k": "v"}, + submission_user_id="u", + submission_end_user_id="eu", + ) + assert form.status == HumanInputFormStatus.SUBMITTED + assert form.submitted_at == fixed_now + assert record.submitted_data == {"k": "v"} + + +def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type] + + +def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.TIMEOUT, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r") + assert record.status == HumanInputFormStatus.TIMEOUT + + +def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.SUBMITTED, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form already submitted"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + + +def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + selected_action_id="a", + submitted_data="{}", + submission_user_id="u", + submission_end_user_id="eu", + completed_by_recipient_id="r", + status=HumanInputFormStatus.WAITING, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + assert form.status == HumanInputFormStatus.EXPIRED + assert form.selected_action_id is None + assert form.submitted_data is None + assert form.submission_user_id is None + assert form.submission_end_user_id is None + assert form.completed_by_recipient_id is None + assert record.status == HumanInputFormStatus.EXPIRED + + +def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(forms={})) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT) diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index c66e50437a..232ab07882 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -1,84 +1,291 @@ -from datetime import datetime +from datetime import UTC, datetime from unittest.mock import MagicMock from uuid import uuid4 -from sqlalchemy import create_engine +import pytest +from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType -from models import Account, WorkflowRun +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom -def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: - engine = create_engine("sqlite:///:memory:") - real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) - - user = MagicMock(spec=Account) - user.id = str(uuid4()) - user.current_tenant_id = str(uuid4()) - - repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=real_session_factory, - user=user, - app_id="app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = False - repository._session_factory = MagicMock(return_value=session_context) - return repository - - -def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: - return WorkflowExecution.new( - id_=execution_id, - workflow_id="workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "hello"}, - started_at=started_at, - ) - - -def test_save_uses_execution_started_at_when_record_does_not_exist(): +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + session_factory = MagicMock(spec=sessionmaker) session = MagicMock() session.get.return_value = None - repository = _build_repository_with_mocked_session(session) - - started_at = datetime(2026, 1, 1, 12, 0, 0) - execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) - - repository.save(execution) - - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == started_at - session.commit.assert_called_once() + session_factory.return_value.__enter__.return_value = session + return session_factory -def test_save_preserves_existing_created_at_when_record_already_exists(): - session = MagicMock() - repository = _build_repository_with_mocked_session(session) +@pytest.fixture +def mock_engine(): + """Mock SQLAlchemy Engine.""" + return MagicMock(spec=Engine) - execution_id = str(uuid4()) - existing_created_at = datetime(2026, 1, 1, 12, 0, 0) - existing_run = WorkflowRun() - existing_run.id = execution_id - existing_run.tenant_id = repository._tenant_id - existing_run.created_at = existing_created_at - session.get.return_value = existing_run - execution = _build_execution( - execution_id=execution_id, - started_at=datetime(2026, 1, 1, 12, 30, 0), +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = MagicMock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = MagicMock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + outputs={"output1": "result1"}, + status=WorkflowExecutionStatus.SUCCEEDED, + error_message="", + total_tokens=100, + total_steps=5, + exceptions_count=0, + started_at=datetime.now(UTC), + finished_at=datetime.now(UTC), ) - repository.save(execution) - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == existing_created_at - session.commit.assert_called_once() +class TestSQLAlchemyWorkflowExecutionRepository: + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + app_id = "test_app_id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from + ) + + assert repo._session_factory == mock_session_factory + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role == CreatorUserRole.ACCOUNT + + def test_init_with_engine(self, mock_engine, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_engine, + user=mock_account, + app_id="test_app_id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert isinstance(repo._session_factory, sessionmaker) + assert repo._session_factory.kw["bind"] == mock_engine + + def test_init_invalid_session_factory(self, mock_account): + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowExecutionRepository( + session_factory="invalid", user=mock_account, app_id=None, triggered_from=None + ) + + def test_init_no_tenant_id(self, mock_session_factory): + user = MagicMock(spec=Account) + user.current_tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None + ) + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None + ) + assert repo._tenant_id == mock_end_user.tenant_id + assert repo._creator_user_role == CreatorUserRole.END_USER + + def test_to_domain_model(self, mock_session_factory, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + db_model = MagicMock(spec=WorkflowRun) + db_model.id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.type = "workflow" + db_model.version = "1.0" + db_model.inputs_dict = {"in": "val"} + db_model.outputs_dict = {"out": "val"} + db_model.graph_dict = {"nodes": []} + db_model.status = "succeeded" + db_model.error = "some error" + db_model.total_tokens = 50 + db_model.total_steps = 3 + db_model.exceptions_count = 1 + db_model.created_at = datetime.now(UTC) + db_model.finished_at = datetime.now(UTC) + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.id_ == db_model.id + assert domain_model.workflow_id == db_model.workflow_id + assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED + assert domain_model.inputs == db_model.inputs_dict + assert domain_model.error_message == "some error" + + def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + + # Make elapsed time deterministic to avoid flaky tests + sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC) + sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC) + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.id == sample_workflow_execution.id_ + assert db_model.tenant_id == repo._tenant_id + assert db_model.app_id == "test_app" + assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + assert db_model.status == sample_workflow_execution.status.value + assert db_model.total_tokens == sample_workflow_execution.total_tokens + assert db_model.elapsed_time == 10.0 + + def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Test with empty/None fields + sample_workflow_execution.graph = None + sample_workflow_execution.inputs = None + sample_workflow_execution.outputs = None + sample_workflow_execution.error_message = None + sample_workflow_execution.finished_at = None + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.graph is None + assert db_model.inputs is None + assert db_model.outputs is None + assert db_model.error is None + assert db_model.elapsed_time == 0 + + def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=None, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + db_model = repo._to_db_model(sample_workflow_execution) + assert not hasattr(db_model, "app_id") or db_model.app_id is None + assert db_model.tenant_id == repo._tenant_id + + def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + # Test triggered_from missing + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(sample_workflow_execution) + + repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(sample_workflow_execution) + + repo._creator_user_id = "some_id" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(sample_workflow_execution) + + def test_save(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + session = mock_session_factory.return_value.__enter__.return_value + session.merge.assert_called_once() + session.commit.assert_called_once() + + # Check cache + assert sample_workflow_execution.id_ in repo._execution_cache + cached_model = repo._execution_cache[sample_workflow_execution.id_] + assert cached_model.id == sample_workflow_execution.id_ + + def test_save_uses_execution_started_at_when_record_does_not_exist( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + sample_workflow_execution.started_at = started_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = None + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + def test_save_preserves_existing_created_at_when_record_already_exists( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + execution_id = sample_workflow_execution.id_ + existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repo._tenant_id + existing_run.created_at = existing_created_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = existing_run + + sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC) + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..73de15e2cf --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock + +import psycopg2.errors +import pytest +from sqlalchemy import Engine, create_engine +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, + _deterministic_json_dump, + _filter_by_offload_type, + _find_first, + _replace_or_append_offload, +) +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from models import Account, EndUser +from models.enums import ExecutionOffLoadType +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom + + +def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account: + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + return user + + +def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser: + user = Mock(spec=EndUser) + user.id = user_id + user.tenant_id = tenant_id + return user + + +def _execution( + *, + execution_id: str = "exec-id", + node_execution_id: str = "node-exec-id", + workflow_run_id: str = "run-id", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + inputs: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, +) -> WorkflowNodeExecution: + return WorkflowNodeExecution( + id=execution_id, + node_execution_id=node_execution_id, + workflow_id="workflow-id", + workflow_execution_id=workflow_run_id, + index=1, + predecessor_node_id=None, + node_id="node-id", + node_type=BuiltinNodeTypes.LLM, + title="Title", + inputs=inputs, + outputs=outputs, + process_data=process_data, + status=status, + error=None, + elapsed_time=1.0, + metadata=metadata, + created_at=datetime.now(UTC), + finished_at=None, + ) + + +class _SessionCtx: + def __init__(self, session: Any): + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _session_factory(session: Any) -> sessionmaker: + factory = Mock(spec=sessionmaker) + factory.return_value = _SessionCtx(session) + return factory + + +def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + engine: Engine = create_engine("sqlite:///:memory:") + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=engine, + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert isinstance(repo._session_factory, sessionmaker) + + sm = Mock(spec=sessionmaker) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sm, + user=_mock_end_user(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + assert repo._creator_user_role.value == "end_user" + + +def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type] + session_factory=object(), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + user = _mock_account() + user.current_tenant_id = None + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=user, + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None: + created: dict[str, Any] = {} + + class FakeTruncator: + def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int): + created.update( + { + "max_size_bytes": max_size_bytes, + "array_element_limit": array_element_limit, + "string_length_limit": string_length_limit, + } + ) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator", + FakeTruncator, + ) + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + _ = repo._create_truncator() + assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + + +def test_helpers_find_first_and_replace_or_append_and_filter() -> None: + assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}' + assert _find_first([], lambda _: True) is None + assert _find_first([1, 2, 3], lambda x: x > 1) == 2 + + off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2 + + replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)) + assert len(replaced) == 2 + assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS] + + +def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1}) + + # Happy path: deterministic json dump should be sorted + db_model = repo._to_db_model(execution) + assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1} + assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1 + + repo._triggered_from = None + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(execution) + + +def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution() + db_model = repo._to_db_model(execution) + assert db_model.app_id == "app" + + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(execution) + + repo._creator_user_id = "user" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(execution) + + +def test_is_duplicate_key_error_and_regenerate_id( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + assert repo._is_duplicate_key_error(duplicate_error) is True + assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False + + execution = _execution(execution_id="old-id") + db_model = WorkflowNodeExecutionModel() + db_model.id = "old-id" + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + caplog.set_level(logging.WARNING) + repo._regenerate_id_on_duplicate(execution, db_model) + assert execution.id == "new-id" + assert db_model.id == "new-id" + assert any("Duplicate key conflict" in r.message for r in caplog.records) + + +def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id1" + db_model.node_execution_id = "node1" + db_model.foo = "bar" # type: ignore[attr-defined] + db_model.__dict__["_private"] = "x" + + existing = SimpleNamespace() + session.get.return_value = existing + repo._persist_to_database(db_model) + assert existing.foo == "bar" + session.add.assert_not_called() + assert repo._node_execution_cache["node1"] is db_model + + session.reset_mock() + session.get.return_value = None + repo._node_execution_cache.clear() + repo._persist_to_database(db_model) + session.add.assert_called_once_with(db_model) + assert repo._node_execution_cache["node1"] is db_model + + +def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return value, False + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None + + +def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None: + uploaded: dict[str, Any] = {} + + class FakeFileService: + def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def] + uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user}) + return SimpleNamespace(id="file-id", key="file-key") + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService() + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id") + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return {"truncated": True}, True + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + + result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS) + assert result is not None + assert result.truncated_value == {"truncated": True} + assert uploaded["filename"].startswith("node_execution_exec_inputs.json") + assert result.offload.file_id == "file-id" + assert result.offload.type_ == ExecutionOffLoadType.INPUTS + + +def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"trunc": "i"}) + db_model.process_data = json.dumps({"trunc": "p"}) + db_model.outputs = json.dumps({"trunc": "o"}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = json.dumps({"total_tokens": 3}) + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + + off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA) + off_in.file = SimpleNamespace(key="k-in") + off_out.file = SimpleNamespace(key="k-out") + off_proc.file = SimpleNamespace(key="k-proc") + db_model.offload_data = [off_out, off_in, off_proc] + + def fake_load(key: str) -> bytes: + return json.dumps({"full": key}).encode() + + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load) + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"full": "k-in"} + assert domain.outputs == {"full": "k-out"} + assert domain.process_data == {"full": "k-proc"} + assert domain.get_truncated_inputs() == {"trunc": "i"} + assert domain.get_truncated_outputs() == {"trunc": "o"} + assert domain.get_truncated_process_data() == {"trunc": "p"} + + +def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"i": 1}) + db_model.process_data = json.dumps({"p": 2}) + db_model.outputs = json.dumps({"o": 3}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + db_model.offload_data = [] + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"i": 1} + assert domain.outputs == {"o": 3} + + +def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeConverter: + def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]: + return {"wrapped": values["a"]} + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter", + FakeConverter, + ) + assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}' + + +def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace( + id="id", + offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)], + inputs=None, + outputs=None, + process_data=None, + ) + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + trunc_result = SimpleNamespace( + truncated_value={"trunc": True}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"), + ) + monkeypatch.setattr( + repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None + ) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + # Inputs should be truncated, outputs/process_data encoded directly + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.inputs) == {"trunc": True} + assert json.loads(db_model.outputs) == {"b": 2} + assert json.loads(db_model.process_data) == {"c": 3} + assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data) + assert execution.get_truncated_inputs() == {"trunc": True} + + +def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + existing = SimpleNamespace( + id="id", + offload_data=[], + inputs=None, + outputs=None, + process_data=None, + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = existing + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any: + if values == {"b": 2}: + return SimpleNamespace( + truncated_value={"b": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"), + ) + if values == {"c": 3}: + return SimpleNamespace( + truncated_value={"c": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"), + ) + return None + + monkeypatch.setattr(repo, "_truncate_and_upload", trunc) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.outputs) == {"b": "trunc"} + assert json.loads(db_model.process_data) == {"c": "trunc"} + assert execution.get_truncated_outputs() == {"b": "trunc"} + assert execution.get_truncated_process_data() == {"c": "trunc"} + + +def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = None + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}) + fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None) + monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model) + monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values)) + + repo.save_execution_data(execution) + merged = session.merge.call_args.args[0] + assert merged.inputs == '{"a": 1}' + + +def test_save_retries_duplicate_and_logs_non_duplicate( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(execution_id="id") + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + other_error = IntegrityError("other", params=None, orig=None) + + calls = {"n": 0} + + def persist(_db_model: Any) -> None: + calls["n"] += 1 + if calls["n"] == 1: + raise duplicate_error + + monkeypatch.setattr(repo, "_persist_to_database", persist) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + repo.save(execution) + assert execution.id == "new-id" + assert repo._node_execution_cache[execution.node_execution_id] is not None + + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error)) + with pytest.raises(IntegrityError): + repo.save(_execution(execution_id="id2", node_execution_id="node2")) + assert any("Non-duplicate key integrity error" in r.message for r in caplog.records) + + +def test_save_logs_and_reraises_on_unexpected_error( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom"))) + with pytest.raises(RuntimeError, match="boom"): + repo.save(_execution(execution_id="id3", node_execution_id="node3")) + assert any("Failed to save workflow node execution" in r.message for r in caplog.records) + + +def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def __init__(self) -> None: + self.where_calls = 0 + self.order_by_args: tuple[Any, ...] | None = None + + def where(self, *_args: Any) -> FakeStmt: + self.where_calls += 1 + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.order_by_args = args + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + model1 = SimpleNamespace(node_execution_id="n1") + model2 = SimpleNamespace(node_execution_id=None) + session = MagicMock() + session.scalars.return_value.all.return_value = [model1, model2] + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + order = OrderConfig(order_by=["index", "missing"], order_direction="desc") + db_models = repo.get_db_models_by_workflow_run("run", order) + assert db_models == [model1, model2] + assert repo._node_execution_cache["n1"] is model1 + assert stmt.order_by_args is not None + + +def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def where(self, *_args: Any) -> FakeStmt: + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.args = args # type: ignore[attr-defined] + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc")) + + +def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")] + monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models) + monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}") + + class FakeExecutor: + def __enter__(self) -> FakeExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def map(self, func, items, timeout: int): # type: ignore[no-untyped-def] + assert timeout == 30 + return list(map(func, items)) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor", + lambda max_workers: FakeExecutor(), + ) + + result = repo.get_by_workflow_run("run", order_config=None) + assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/schemas/test_registry.py b/api/tests/unit_tests/core/schemas/test_registry.py new file mode 100644 index 0000000000..5749e72eb0 --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_registry.py @@ -0,0 +1,137 @@ +import json +from unittest.mock import patch + +from core.schemas.registry import SchemaRegistry + + +class TestSchemaRegistry: + def test_initialization(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + registry = SchemaRegistry(str(base_dir)) + assert registry.base_dir == base_dir + assert registry.versions == {} + assert registry.metadata == {} + + def test_default_registry_singleton(self): + registry1 = SchemaRegistry.default_registry() + registry2 = SchemaRegistry.default_registry() + assert registry1 is registry2 + assert isinstance(registry1, SchemaRegistry) + + def test_load_all_versions_non_existent_dir(self, tmp_path): + base_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(base_dir)) + registry.load_all_versions() + assert registry.versions == {} + + def test_load_all_versions_filtering(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + (base_dir / "not_a_version_dir").mkdir() + (base_dir / "v1").mkdir() + (base_dir / "some_file.txt").write_text("content") + + registry = SchemaRegistry(str(base_dir)) + with patch.object(registry, "_load_version_dir") as mock_load: + registry.load_all_versions() + mock_load.assert_called_once() + assert mock_load.call_args[0][0] == "v1" + + def test_load_version_dir_filtering(self, tmp_path): + version_dir = tmp_path / "v1" + version_dir.mkdir() + (version_dir / "schema1.json").write_text("{}") + (version_dir / "not_a_schema.txt").write_text("content") + + registry = SchemaRegistry(str(tmp_path)) + with patch.object(registry, "_load_schema") as mock_load: + registry._load_version_dir("v1", version_dir) + mock_load.assert_called_once() + assert mock_load.call_args[0][1] == "schema1" + + def test_load_version_dir_non_existent(self, tmp_path): + version_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(tmp_path)) + registry._load_version_dir("v1", version_dir) + assert "v1" not in registry.versions + + def test_load_schema_success(self, tmp_path): + schema_path = tmp_path / "test.json" + schema_content = {"title": "Test Schema", "description": "A test schema"} + schema_path.write_text(json.dumps(schema_content)) + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "test", schema_path) + + assert registry.versions["v1"]["test"] == schema_content + uri = "https://dify.ai/schemas/v1/test.json" + assert registry.metadata[uri]["title"] == "Test Schema" + assert registry.metadata[uri]["version"] == "v1" + + def test_load_schema_invalid_json(self, tmp_path, caplog): + schema_path = tmp_path / "invalid.json" + schema_path.write_text("invalid json") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "invalid", schema_path) + + assert "Failed to load schema v1/invalid" in caplog.text + + def test_load_schema_os_error(self, tmp_path, caplog): + schema_path = tmp_path / "error.json" + schema_path.write_text("{}") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + + with patch("builtins.open", side_effect=OSError("Read error")): + registry._load_schema("v1", "error", schema_path) + + assert "Failed to load schema v1/error" in caplog.text + + def test_get_schema(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"type": "object"}}} + + # Valid URI + assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"} + + # Invalid URI + assert registry.get_schema("invalid-uri") is None + + # Missing version + assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None + + def test_list_versions(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v2": {}, "v1": {}} + assert registry.list_versions() == ["v1", "v2"] + + def test_list_schemas(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"b": {}, "a": {}}} + + assert registry.list_schemas("v1") == ["a", "b"] + assert registry.list_schemas("v2") == [] + + def test_get_all_schemas_for_version(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"title": "Test Label"}}} + + results = registry.get_all_schemas_for_version("v1") + assert len(results) == 1 + assert results[0]["name"] == "test" + assert results[0]["label"] == "Test Label" + assert results[0]["schema"] == {"title": "Test Label"} + + # Default label if title missing + registry.versions["v1"]["no_title"] = {} + results = registry.get_all_schemas_for_version("v1") + item = next(r for r in results if r["name"] == "no_title") + assert item["label"] == "no_title" + + # Empty if version missing + assert registry.get_all_schemas_for_version("v2") == [] diff --git a/api/tests/unit_tests/core/schemas/test_schema_manager.py b/api/tests/unit_tests/core/schemas/test_schema_manager.py new file mode 100644 index 0000000000..cb07340c6d --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_schema_manager.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock, patch + +from core.schemas.registry import SchemaRegistry +from core.schemas.schema_manager import SchemaManager + + +def test_init_with_provided_registry(): + mock_registry = MagicMock(spec=SchemaRegistry) + manager = SchemaManager(registry=mock_registry) + assert manager.registry == mock_registry + + +@patch("core.schemas.schema_manager.SchemaRegistry.default_registry") +def test_init_with_default_registry(mock_default_registry): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_default_registry.return_value = mock_registry + + manager = SchemaManager() + + mock_default_registry.assert_called_once() + assert manager.registry == mock_registry + + +def test_get_all_schema_definitions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}] + mock_registry.get_all_schemas_for_version.return_value = expected_definitions + + manager = SchemaManager(registry=mock_registry) + result = manager.get_all_schema_definitions(version="v2") + + mock_registry.get_all_schemas_for_version.assert_called_once_with("v2") + assert result == expected_definitions + + +def test_get_schema_by_name_success(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_schema = {"type": "object"} + mock_registry.get_schema.return_value = mock_schema + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("my_schema", version="v1") + + expected_uri = "https://dify.ai/schemas/v1/my_schema.json" + mock_registry.get_schema.assert_called_once_with(expected_uri) + assert result == {"name": "my_schema", "schema": mock_schema} + + +def test_get_schema_by_name_not_found(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_registry.get_schema.return_value = None + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("non_existent", version="v1") + + assert result is None + + +def test_list_available_schemas(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_schemas = ["schema1", "schema2"] + mock_registry.list_schemas.return_value = expected_schemas + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_schemas(version="v1") + + mock_registry.list_schemas.assert_called_once_with("v1") + assert result == expected_schemas + + +def test_list_available_versions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_versions = ["v1", "v2"] + mock_registry.list_versions.return_value = expected_versions + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_versions() + + mock_registry.list_versions.assert_called_once() + assert result == expected_versions diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 5ceaa08893..ae5638784c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -110,7 +110,7 @@ def test_encrypt_tool_parameters(): assert encrypted["plain"] == "x" -def test_decrypt_tool_parameters_cache_hit_and_miss(): +def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch): manager = _build_manager() with ( @@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache(): mock_delete.assert_called_once() -def test_configuration_manager_decrypt_suppresses_errors(): +def test_configuration_manager_decrypt_suppresses_errors(monkeypatch): manager = _build_manager() with ( patch.object(ToolParameterCache, "get", return_value=None), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py new file mode 100644 index 0000000000..bc00b49fba --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py @@ -0,0 +1,145 @@ +import queue +from collections.abc import Generator +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue +from dify_graph.graph_engine.worker import Worker +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent + + +def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: + fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) + + worker = Worker( + ready_queue=InMemoryReadyQueue(), + event_queue=queue.Queue(), + graph=MagicMock(), + layers=[], + ) + node = SimpleNamespace( + execution_id="exec-1", + id="node-1", + node_type=BuiltinNodeTypes.LLM, + ) + + event = worker._build_fallback_failure_event(node, RuntimeError("boom")) + + assert event.start_at == fixed_time + assert event.finished_at == fixed_time + assert event.error == "boom" + assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert event.node_run_result.error == "boom" + assert event.node_run_result.error_type == "RuntimeError" + + +def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: + start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + failure_time = start_at + timedelta(seconds=5) + captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] + + class FakeNode: + execution_id = "exec-1" + id = "node-1" + node_type = BuiltinNodeTypes.LLM + + def ensure_execution_id(self) -> str: + return self.execution_id + + def run(self) -> Generator[NodeRunStartedEvent, None, None]: + yield NodeRunStartedEvent( + id=self.execution_id, + node_id=self.id, + node_type=self.node_type, + node_title="LLM", + start_at=start_at, + ) + + worker = Worker( + ready_queue=MagicMock(), + event_queue=MagicMock(), + graph=MagicMock(nodes={"node-1": FakeNode()}), + layers=[], + ) + + worker._ready_queue.get.side_effect = ["node-1"] + + def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: + captured_events.append(event) + if len(captured_events) == 1: + raise RuntimeError("queue boom") + worker.stop() + + worker._event_queue.put.side_effect = put_side_effect + + with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + worker.run() + + fallback_event = captured_events[-1] + + assert isinstance(fallback_event, NodeRunFailedEvent) + assert fallback_event.start_at == start_at + assert fallback_event.finished_at == failure_time + assert fallback_event.error == "queue boom" + assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED + + +def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: + parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + child_start = parent_start + timedelta(seconds=3) + failure_time = parent_start + timedelta(seconds=5) + captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] + + class FakeIterationNode: + execution_id = "iteration-exec" + id = "iteration-node" + node_type = BuiltinNodeTypes.ITERATION + + def ensure_execution_id(self) -> str: + return self.execution_id + + def run(self) -> Generator[NodeRunStartedEvent, None, None]: + yield NodeRunStartedEvent( + id=self.execution_id, + node_id=self.id, + node_type=self.node_type, + node_title="Iteration", + start_at=parent_start, + ) + yield NodeRunStartedEvent( + id="child-exec", + node_id="child-node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=child_start, + in_iteration_id=self.id, + ) + + worker = Worker( + ready_queue=MagicMock(), + event_queue=MagicMock(), + graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), + layers=[], + ) + + worker._ready_queue.get.side_effect = ["iteration-node"] + + def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: + captured_events.append(event) + if len(captured_events) == 2: + raise RuntimeError("queue boom") + worker.stop() + + worker._event_queue.put.side_effect = put_side_effect + + with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + worker.run() + + fallback_event = captured_events[-1] + + assert isinstance(fallback_event, NodeRunFailedEvent) + assert fallback_event.start_at == parent_start + assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py new file mode 100644 index 0000000000..8660449032 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -0,0 +1,63 @@ +import time +from contextlib import nullcontext +from datetime import UTC, datetime + +import pytest + +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_events import NodeRunSucceededEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.iteration_node import IterationNode + + +def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: + node = IterationNode.__new__(IterationNode) + node._node_data = IterationNodeData( + title="Parallel Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration", "output"], + is_parallel=True, + parallel_nums=2, + error_handle_mode=ErrorHandleMode.TERMINATED, + ) + node._capture_execution_context = lambda: nullcontext() + node._sync_conversation_variables_from_snapshot = lambda snapshot: None + node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) + + def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): + return ( + 0.1 + (index * 0.1), + [ + NodeRunSucceededEvent( + id=f"exec-{index}", + node_id=f"llm-{index}", + node_type=BuiltinNodeTypes.LLM, + start_at=datetime.now(UTC).replace(tzinfo=None), + ), + ], + f"output-{item}", + {}, + LLMUsage.empty_usage(), + ) + + node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + + outputs: list[object] = [] + iter_run_map: dict[str, float] = {} + usage_accumulator = [LLMUsage.empty_usage()] + + generator = node._execute_parallel_iterations( + iterator_list_value=["a", "b"], + outputs=outputs, + iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, + ) + + for _ in generator: + # Simulate a slow consumer replaying buffered events. + time.sleep(0.02) + + assert outputs == ["output-a", "output-b"] + assert iter_run_map["0"] == pytest.approx(0.1) + assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..feb560bbc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,6 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py new file mode 100644 index 0000000000..acecbf4944 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -0,0 +1,272 @@ +from unittest import mock + +import pytest + +from core.model_manager import ModelInstance +from dify_graph.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage +from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage +from dify_graph.nodes.llm.exc import NoPromptFoundError +from dify_graph.runtime import VariablePool + + +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool + + +def _fetch_prompt_messages_with_mocked_content(content): + variable_pool = VariablePool.empty() + model_instance = mock.MagicMock(spec=ModelInstance) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + return_value=mock.MagicMock(features=[]), + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_list_messages", + return_value=[SystemPromptMessage(content=content)], + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + return_value=[], + ), + ): + return llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context=None, + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=["END"], + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + template_renderer=None, + ) + + +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + +def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): + with pytest.raises(NoPromptFoundError): + _fetch_prompt_messages_with_mocked_content( + [ + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + +def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items(): + prompt_messages, stop = _fetch_prompt_messages_with_mocked_content( + [ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + assert stop == ["END"] + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain(): + prompt_messages, stop = _fetch_prompt_messages_with_mocked_content( + [ + TextPromptMessageContent(data="You are"), + TextPromptMessageContent(data=" a classifier."), + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + assert stop == ["END"] + assert prompt_messages == [ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are"), + TextPromptMessageContent(data=" a classifier."), + ] + ) + ] diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py new file mode 100644 index 0000000000..9aeab0409e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -0,0 +1,63 @@ +from collections.abc import Mapping + +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: + init_params = build_test_graph_init_params( + graph_config=graph_config, + user_from="account", + invoke_from="debugger", + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable(user_id="user", files=[]), + user_inputs={"payload": "value"}, + ), + start_at=0.0, + ) + return init_params, runtime_state + + +def _build_node_config() -> NodeConfigDict: + return NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": TRIGGER_PLUGIN_NODE_TYPE, + "title": "Trigger Event", + "plugin_id": "plugin-id", + "provider_id": "provider-id", + "event_name": "event-name", + "subscription_id": "subscription-id", + "plugin_unique_identifier": "plugin-unique-identifier", + "event_parameters": {}, + }, + } + ) + + +def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: + init_params, runtime_state = _build_context(graph_config={}) + node = TriggerEventNode( + id="node-1", + config=_build_node_config(), + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { + "provider_id": "provider-id", + "event_name": "event-name", + "plugin_unique_identifier": "plugin-unique-identifier", + } diff --git a/api/tests/unit_tests/dify_graph/node_events/test_base.py b/api/tests/unit_tests/dify_graph/node_events/test_base.py new file mode 100644 index 0000000000..6d789abac0 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/node_events/test_base.py @@ -0,0 +1,19 @@ +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.node_events.base import NodeRunResult + + +def test_node_run_result_accepts_trigger_info_metadata() -> None: + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + metadata={ + WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { + "provider_id": "provider-id", + "event_name": "event-name", + } + }, + ) + + assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { + "provider_id": "provider-id", + "event_name": "event-name", + } diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 248aa0b145..bf548f69cf 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -1,7 +1,11 @@ +import threading import time +from dataclasses import dataclass +from typing import cast import pytest +from libs.broadcast_channel.exc import SubscriptionClosedError from libs.broadcast_channel.redis.streams_channel import ( StreamsBroadcastChannel, StreamsTopic, @@ -22,6 +26,7 @@ class FakeStreamsRedis: self._store: dict[str, list[tuple[str, dict]]] = {} self._next_id: dict[str, int] = {} self._expire_calls: dict[str, int] = {} + self._dollar_snapshots: dict[str, int] = {} # Publisher API def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: @@ -47,7 +52,9 @@ class FakeStreamsRedis: # Find position strictly greater than last_id start_idx = 0 - if last_id != "0-0": + if last_id == "$": + start_idx = self._dollar_snapshots.setdefault(key, len(entries)) + elif last_id != "0-0": for i, (eid, _f) in enumerate(entries): if eid == last_id: start_idx = i + 1 @@ -63,6 +70,55 @@ class FakeStreamsRedis: return [(key, batch)] +class FailExpireRedis(FakeStreamsRedis): + def expire(self, key: str, seconds: int) -> None: + raise RuntimeError("expire failed") + + +class BlockingRedis: + def __init__(self) -> None: + self._release = threading.Event() + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._release.wait(timeout=block / 1000.0 if block else None) + return [] + + def release(self) -> None: + self._release.set() + + +@dataclass(frozen=True) +class ListenPayloadCase: + name: str + fields: object + expected_messages: list[bytes] + + +def build_listen_payload_cases() -> list[ListenPayloadCase]: + return [ + ListenPayloadCase( + name="string_payload_is_encoded", + fields={b"data": "hello"}, + expected_messages=[b"hello"], + ), + ListenPayloadCase( + name="bytearray_payload_is_converted", + fields={b"data": bytearray(b"world")}, + expected_messages=[b"world"], + ), + ListenPayloadCase( + name="non_dict_fields_are_ignored", + fields=[("data", b"ignored")], + expected_messages=[], + ), + ListenPayloadCase( + name="missing_payload_is_ignored", + fields={b"other": b"ignored"}, + expected_messages=[], + ), + ] + + @pytest.fixture def fake_redis() -> FakeStreamsRedis: return FakeStreamsRedis() @@ -94,21 +150,37 @@ class TestStreamsBroadcastChannel: # Expire called after publish assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("producer-subscriber") + + assert topic.as_producer() is topic + assert topic.as_subscriber() is topic + + def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture): + channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60) + topic = channel.topic("expire-warning") + + topic.publish(b"payload") + + assert "Failed to set expire for stream key" in caplog.text + class TestStreamsSubscription: - def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + def test_subscribe_only_receives_messages_published_after_subscription_starts( + self, + streams_channel: StreamsBroadcastChannel, + ): topic = streams_channel.topic("gamma") - # Pre-publish events before subscribing (late subscriber) - topic.publish(b"e1") - topic.publish(b"e2") + topic.publish(b"before-subscribe") sub = topic.subscribe() assert isinstance(sub, _StreamsSubscription) received: list[bytes] = [] with sub: - # Give listener thread a moment to xread - time.sleep(0.05) + assert sub.receive(timeout=0.05) is None + topic.publish(b"after-subscribe-1") + topic.publish(b"after-subscribe-2") # Drain using receive() to avoid indefinite iteration in tests for _ in range(5): msg = sub.receive(timeout=0.1) @@ -116,7 +188,7 @@ class TestStreamsSubscription: break received.append(msg) - assert received == [b"e1", b"e2"] + assert received == [b"after-subscribe-1", b"after-subscribe-2"] def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("delta") @@ -132,8 +204,6 @@ class TestStreamsSubscription: # Listener running; now close and ensure no crash sub.close() # After close, receive should raise SubscriptionClosedError - from libs.broadcast_channel.exc import SubscriptionClosedError - with pytest.raises(SubscriptionClosedError): sub.receive() @@ -143,3 +213,141 @@ class TestStreamsSubscription: topic.publish(b"payload") # No expire recorded when retention is disabled assert fake_redis._expire_calls.get("stream:zeta") is None + + @pytest.mark.parametrize( + ("case"), + build_listen_payload_cases(), + ids=lambda case: cast(ListenPayloadCase, case).name, + ) + def test_listener_normalizes_supported_payloads_and_ignores_unsupported_shapes(self, case: ListenPayloadCase): + class OneShotRedis: + def __init__(self, fields: object) -> None: + self._fields = fields + self._calls = 0 + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._calls += 1 + if self._calls == 1: + key = next(iter(streams)) + return [(key, [("1-0", self._fields)])] + subscription._closed.set() + return [] + + subscription = _StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape") + subscription._listen() + + received: list[bytes] = [] + while not subscription._queue.empty(): + item = subscription._queue.get_nowait() + if item is subscription._SENTINEL: + break + received.append(bytes(item)) + + assert received == case.expected_messages + assert subscription._last_id == "1-0" + + def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("iter") + subscription = topic.subscribe() + iterator = iter(subscription) + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"iter-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + assert next(iterator) == b"iter-message" + + subscription.close() + publisher.join(timeout=1) + with pytest.raises(StopIteration): + next(iterator) + + def test_receive_with_none_timeout_blocks_until_message_arrives(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("blocking") + subscription = topic.subscribe() + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"blocking-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + try: + assert subscription.receive(timeout=None) == b"blocking-message" + finally: + subscription.close() + publisher.join(timeout=1) + + def test_receive_raises_when_queue_contains_close_sentinel(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:sentinel") + subscription._listener = threading.current_thread() + subscription._queue.put_nowait(subscription._SENTINEL) + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_close_before_listener_starts_is_a_noop(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:not-started") + + subscription.close() + + assert subscription._listener is None + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_start_if_needed_returns_immediately_for_closed_subscription(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed") + subscription._closed.set() + + subscription._start_if_needed() + + assert subscription._listener is None + + def test_iterator_skips_none_results_and_keeps_polling(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:iterator-none") + items = iter([None, b"event"]) + + subscription._start_if_needed = lambda: None # type: ignore[method-assign] + + def fake_receive(timeout: float | None = 0.1) -> bytes | None: + value = next(items) + if value is not None: + subscription._closed.set() + return value + + subscription.receive = fake_receive # type: ignore[method-assign] + + assert next(iter(subscription)) == b"event" + + def test_close_logs_warning_when_listener_does_not_stop_in_time( + self, + caplog: pytest.LogCaptureFixture, + ): + blocking_redis = BlockingRedis() + subscription = _StreamsSubscription(blocking_redis, "stream:slow-close") + + subscription._start_if_needed() + listener = subscription._listener + assert listener is not None + + original_join = listener.join + original_is_alive = listener.is_alive + + def delayed_join(timeout: float | None = None) -> None: + original_join(0.01) + + listener.join = delayed_join # type: ignore[method-assign] + listener.is_alive = lambda: True # type: ignore[method-assign] + + try: + subscription.close() + assert "did not stop within timeout" in caplog.text + finally: + listener.join = original_join # type: ignore[method-assign] + listener.is_alive = original_is_alive # type: ignore[method-assign] + blocking_redis.release() + original_join(timeout=1) diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index a94ba0c00b..8613d89215 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -130,6 +130,25 @@ class TestLoginRequired: assert result == "Synced content" setup_app.ensure_sync.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) + def test_patched_current_user_without_login_manager(self, app: Flask): + """Test that patched current_user bypasses login manager bootstrapping.""" + + @login_required + def protected_view(): + return "Protected content" + + mock_user = MockUser("test_user", is_authenticated=True) + mock_proxy = MagicMock() + mock_proxy._get_current_object.return_value = mock_user + + with app.test_request_context(): + app.ensure_sync = lambda func: func + with patch("libs.login.current_user", mock_proxy): + result = protected_view() + assert result == "Protected content" + assert g._login_user == mock_user + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_1_compatibility(self, setup_app: Flask): """Test Flask 1.x compatibility without ensure_sync.""" diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index bc7880ccc8..3918e8ee4b 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -95,13 +95,11 @@ class TestGitHubOAuth(BaseOAuthTest): ], "primary@example.com", ), - # User with no emails - fallback to noreply - ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"), - # User with only secondary email - fallback to noreply + # User with private email (null email and name from API) ( - {"id": 12345, "login": "testuser", "name": "Test User"}, - [{"email": "secondary@example.com", "primary": False}], - "12345+testuser@users.noreply.github.com", + {"id": 12345, "login": "testuser", "name": None, "email": None}, + [{"email": "primary@example.com", "primary": True}], + "primary@example.com", ), ], ) @@ -118,9 +116,54 @@ class TestGitHubOAuth(BaseOAuthTest): user_info = oauth.get_user_info("test_token") assert user_info.id == str(user_data["id"]) - assert user_info.name == user_data["name"] + assert user_info.name == (user_data["name"] or "") assert user_info.email == expected_email + @pytest.mark.parametrize( + ("user_data", "email_data"), + [ + # User with no emails + ({"id": 12345, "login": "testuser", "name": "Test User"}, []), + # User with only secondary email + ( + {"id": 12345, "login": "testuser", "name": "Test User"}, + [{"email": "secondary@example.com", "primary": False}], + ), + # User with private email and no primary in emails endpoint + ( + {"id": 12345, "login": "testuser", "name": None, "email": None}, + [], + ), + ], + ) + @patch("httpx.get", autospec=True) + def test_should_raise_error_when_no_primary_email(self, mock_get, oauth, user_data, email_data): + user_response = MagicMock() + user_response.json.return_value = user_data + + email_response = MagicMock() + email_response.json.return_value = email_data + + mock_get.side_effect = [user_response, email_response] + + with pytest.raises(ValueError, match="Keep my email addresses private"): + oauth.get_user_info("test_token") + + @patch("httpx.get", autospec=True) + def test_should_raise_error_when_email_endpoint_fails(self, mock_get, oauth): + user_response = MagicMock() + user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"} + + email_response = MagicMock() + email_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Forbidden", request=MagicMock(), response=MagicMock() + ) + + mock_get.side_effect = [user_response, email_response] + + with pytest.raises(ValueError, match="Keep my email addresses private"): + oauth.get_user_info("test_token") + @patch("httpx.get", autospec=True) def test_should_handle_network_errors(self, mock_get, oauth): mock_get.side_effect = httpx.RequestError("Network error") diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 329fe554ea..b6577daac8 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -16,6 +16,7 @@ from uuid import uuid4 import pytest +from models.enums import ConversationFromSource from models.model import ( App, AppAnnotationHitHistory, @@ -324,7 +325,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=from_end_user_id, ) @@ -345,7 +346,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) conversation._inputs = inputs @@ -364,7 +365,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) inputs = {"query": "Hello", "context": "test"} @@ -383,7 +384,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), summary="Test summary", ) @@ -402,7 +403,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), summary=None, ) @@ -425,7 +426,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), override_model_configs='{"model": "gpt-4"}', ) @@ -446,7 +447,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=from_end_user_id, dialogue_count=5, ) @@ -487,7 +488,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) # Assert @@ -511,7 +512,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message._inputs = inputs @@ -533,7 +534,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) inputs = {"query": "Hello", "context": "test"} @@ -555,7 +556,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, override_model_configs='{"model": "gpt-4"}', ) @@ -578,7 +579,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, message_metadata=json.dumps(metadata), ) @@ -600,7 +601,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, message_metadata=None, ) @@ -627,7 +628,7 @@ class TestMessageModel: answer_unit_price=Decimal("0.0002"), total_price=Decimal("0.0003"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, status="normal", ) message.id = str(uuid4()) @@ -988,7 +989,7 @@ class TestModelIntegration: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) conversation.id = conversation_id @@ -1003,7 +1004,7 @@ class TestModelIntegration: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message.id = message_id @@ -1064,7 +1065,7 @@ class TestModelIntegration: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message.id = message_id @@ -1158,12 +1159,12 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = str(uuid4()) # Mock the database query to return no messages - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1183,12 +1184,12 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id # Mock the database query to return no messages with workflow_run_id - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1215,7 +1216,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1273,7 +1274,7 @@ class TestConversationStatusCount: return mock_result # Act & Assert - with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): + with patch("models.model.db.session.scalars", side_effect=mock_scalars): result = conversation.status_count # Verify only 2 database queries were made (not N+1) @@ -1307,7 +1308,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1336,7 +1337,7 @@ class TestConversationStatusCount: return mock_result # Act - with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): + with patch("models.model.db.session.scalars", side_effect=mock_scalars): result = conversation.status_count # Assert - query should include app_id filter @@ -1361,7 +1362,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1381,7 +1382,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: # Mock the messages query def mock_scalars_side_effect(query): mock_result = MagicMock() @@ -1418,7 +1419,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1437,7 +1438,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: + with patch("models.model.db.session.scalars") as mock_scalars: def mock_scalars_side_effect(query): mock_result = MagicMock() diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 98dd07907a..6c8a91129b 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -15,6 +15,7 @@ from datetime import UTC, datetime from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import ( AppDatasetJoin, ChildChunk, @@ -67,14 +68,14 @@ class TestDatasetModelValidation: data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) # Assert assert dataset.description == "Test description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -86,21 +87,21 @@ class TestDatasetModelValidation: name="High Quality Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) # Assert - assert dataset_high_quality.indexing_technique == "high_quality" - assert dataset_economy.indexing_technique == "economy" - assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST - assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST + assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY + assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY + assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST + assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST def test_dataset_provider_validation(self): """Test dataset provider values.""" @@ -983,7 +984,7 @@ class TestModelIntegration: name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset.id = dataset_id @@ -1019,7 +1020,7 @@ class TestModelIntegration: assert document.dataset_id == dataset_id assert segment.dataset_id == dataset_id assert segment.document_id == document_id - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert document.word_count == 100 assert segment.status == SegmentStatus.COMPLETED diff --git a/api/tests/unit_tests/models/test_enums_creator_user_role.py b/api/tests/unit_tests/models/test_enums_creator_user_role.py new file mode 100644 index 0000000000..6317166fdc --- /dev/null +++ b/api/tests/unit_tests/models/test_enums_creator_user_role.py @@ -0,0 +1,19 @@ +import pytest + +from models.enums import CreatorUserRole + + +def test_creator_user_role_missing_maps_hyphen_to_enum(): + # given an alias with hyphen + value = "end-user" + + # when converting to enum (invokes StrEnum._missing_ override) + role = CreatorUserRole(value) + + # then it should map to END_USER + assert role is CreatorUserRole.END_USER + + +def test_creator_user_role_missing_raises_for_unknown(): + with pytest.raises(ValueError): + CreatorUserRole("unknown") diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index ec84a61c8e..f628e54a4d 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +from models.enums import CredentialSourceType, PaymentStatus from models.provider import ( LoadBalancingModelConfig, Provider, @@ -158,7 +159,7 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == provider_name - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_used == 0 @@ -172,10 +173,10 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="anthropic", - provider_type="system", + provider_type=ProviderType.SYSTEM, is_valid=True, credential_id=credential_id, - quota_type="paid", + quota_type=ProviderQuotaType.PAID, quota_limit=10000, quota_used=500, ) @@ -183,10 +184,10 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == "anthropic" - assert provider.provider_type == "system" + assert provider.provider_type == ProviderType.SYSTEM assert provider.is_valid is True assert provider.credential_id == credential_id - assert provider.quota_type == "paid" + assert provider.quota_type == ProviderQuotaType.PAID assert provider.quota_limit == 10000 assert provider.quota_used == 500 @@ -199,7 +200,7 @@ class TestProviderModel: ) # Assert - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_type == "" assert provider.quota_limit is None @@ -213,7 +214,7 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="openai", - provider_type="custom", + provider_type=ProviderType.CUSTOM, ) # Act @@ -253,7 +254,7 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, is_valid=True, ) @@ -266,13 +267,13 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - quota_type="trial", + quota_type=ProviderQuotaType.TRIAL, quota_limit=1000, quota_used=250, ) # Assert - assert provider.quota_type == "trial" + assert provider.quota_type == ProviderQuotaType.TRIAL assert provider.quota_limit == 1000 assert provider.quota_used == 250 remaining = provider.quota_limit - provider.quota_used @@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=tenant_id, provider_name="openai", - preferred_provider_type="custom", + preferred_provider_type=ProviderType.CUSTOM, ) # Assert assert preferred.tenant_id == tenant_id assert preferred.provider_name == "openai" - assert preferred.preferred_provider_type == "custom" + assert preferred.preferred_provider_type == ProviderType.CUSTOM def test_tenant_preferred_provider_system_type(self): """Test tenant preferred provider with system type.""" @@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=str(uuid4()), provider_name="anthropic", - preferred_provider_type="system", + preferred_provider_type=ProviderType.SYSTEM, ) # Assert - assert preferred.preferred_provider_type == "system" + assert preferred.preferred_provider_type == ProviderType.SYSTEM class TestProviderOrder: @@ -470,7 +471,7 @@ class TestProviderOrder: quantity=1, currency=None, total_amount=None, - payment_status="wait_pay", + payment_status=PaymentStatus.WAIT_PAY, paid_at=None, pay_failed_at=None, refunded_at=None, @@ -481,7 +482,7 @@ class TestProviderOrder: assert order.provider_name == "openai" assert order.account_id == account_id assert order.payment_product_id == "prod_123" - assert order.payment_status == "wait_pay" + assert order.payment_status == PaymentStatus.WAIT_PAY assert order.quantity == 1 def test_provider_order_with_payment_details(self): @@ -502,7 +503,7 @@ class TestProviderOrder: quantity=5, currency="USD", total_amount=9999, - payment_status="paid", + payment_status=PaymentStatus.PAID, paid_at=paid_time, pay_failed_at=None, refunded_at=None, @@ -514,7 +515,7 @@ class TestProviderOrder: assert order.quantity == 5 assert order.currency == "USD" assert order.total_amount == 9999 - assert order.payment_status == "paid" + assert order.payment_status == PaymentStatus.PAID assert order.paid_at == paid_time def test_provider_order_payment_statuses(self): @@ -536,23 +537,23 @@ class TestProviderOrder: } # Act & Assert - Wait pay status - wait_order = ProviderOrder(**base_params, payment_status="wait_pay") - assert wait_order.payment_status == "wait_pay" + wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY) + assert wait_order.payment_status == PaymentStatus.WAIT_PAY # Act & Assert - Paid status - paid_order = ProviderOrder(**base_params, payment_status="paid") - assert paid_order.payment_status == "paid" + paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID) + assert paid_order.payment_status == PaymentStatus.PAID # Act & Assert - Failed status failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} - failed_order = ProviderOrder(**failed_params, payment_status="failed") - assert failed_order.payment_status == "failed" + failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED) + assert failed_order.payment_status == PaymentStatus.FAILED assert failed_order.pay_failed_at is not None # Act & Assert - Refunded status refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} - refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") - assert refunded_order.payment_status == "refunded" + refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED) + assert refunded_order.payment_status == PaymentStatus.REFUNDED assert refunded_order.refunded_at is not None @@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig: name="Secondary API Key", encrypted_config='{"api_key": "encrypted_value"}', credential_id=credential_id, - credential_source_type="custom", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) # Assert assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.credential_id == credential_id - assert config.credential_source_type == "custom" + assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL def test_load_balancing_config_disabled(self): """Test disabled load balancing config.""" diff --git a/api/tests/unit_tests/models/test_tool_models.py b/api/tests/unit_tests/models/test_tool_models.py index 1a75eb9a01..8e3c4da904 100644 --- a/api/tests/unit_tests/models/test_tool_models.py +++ b/api/tests/unit_tests/models/test_tool_models.py @@ -12,7 +12,7 @@ This test suite covers: import json from uuid import uuid4 -from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType from models.tools import ( ApiToolProvider, BuiltinToolProvider, @@ -238,7 +238,7 @@ class TestApiToolProviderValidation: name=provider_name, icon='{"type": "emoji", "value": "🔧"}', schema=schema, - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Custom API for testing", tools_str=json.dumps(tools), credentials_str=json.dumps(credentials), @@ -249,7 +249,7 @@ class TestApiToolProviderValidation: assert api_provider.user_id == user_id assert api_provider.name == provider_name assert api_provider.schema == schema - assert api_provider.schema_type_str == "openapi" + assert api_provider.schema_type_str == ApiProviderSchemaType.OPENAPI assert api_provider.description == "Custom API for testing" def test_api_tool_provider_schema_type_property(self): @@ -261,7 +261,7 @@ class TestApiToolProviderValidation: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -314,7 +314,7 @@ class TestApiToolProviderValidation: name="Weather API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Weather API", tools_str=json.dumps(tools_data), credentials_str="{}", @@ -343,7 +343,7 @@ class TestApiToolProviderValidation: name="Secure API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Secure API", tools_str="[]", credentials_str=json.dumps(credentials_data), @@ -369,7 +369,7 @@ class TestApiToolProviderValidation: name="Privacy API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with privacy policy", tools_str="[]", credentials_str="{}", @@ -391,7 +391,7 @@ class TestApiToolProviderValidation: name="Disclaimer API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with disclaimer", tools_str="[]", credentials_str="{}", @@ -410,7 +410,7 @@ class TestApiToolProviderValidation: name="Default API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API", tools_str="[]", credentials_str="{}", @@ -432,7 +432,7 @@ class TestApiToolProviderValidation: name=provider_name, icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Unique API", tools_str="[]", credentials_str="{}", @@ -454,7 +454,7 @@ class TestApiToolProviderValidation: name="Public API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Public API with no auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -479,7 +479,7 @@ class TestApiToolProviderValidation: name="Query Auth API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with query auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -631,7 +631,7 @@ class TestToolLabelBinding: """Test creating a tool label binding.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN label_name = "search" # Act @@ -655,7 +655,7 @@ class TestToolLabelBinding: # Act label_binding = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name=label_name, ) @@ -667,7 +667,7 @@ class TestToolLabelBinding: """Test multiple labels can be bound to the same tool.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN # Act binding1 = ToolLabelBinding( @@ -688,7 +688,7 @@ class TestToolLabelBinding: def test_tool_label_binding_different_tool_types(self): """Test label bindings for different tool types.""" # Arrange - tool_types = ["builtin", "api", "workflow"] + tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW] # Act & Assert for tool_type in tool_types: @@ -741,7 +741,7 @@ class TestCredentialStorage: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str=json.dumps(credentials), @@ -788,7 +788,7 @@ class TestCredentialStorage: name="Update Test", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str=json.dumps(original_credentials), @@ -897,7 +897,7 @@ class TestToolProviderRelationships: name="User API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -931,7 +931,7 @@ class TestToolProviderRelationships: name="Custom API 1", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -951,12 +951,12 @@ class TestToolProviderRelationships: # Act binding1 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="search", ) binding2 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="web", ) diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index f3b72aa128..ef29b26a7a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -4,12 +4,18 @@ from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE +from core.helper import encrypter from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.file.models import File from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from dify_graph.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import ( + Workflow, + WorkflowDraftVariable, + WorkflowNodeExecutionModel, + is_system_variable_editable, +) def test_environment_variables(): @@ -144,6 +150,36 @@ def test_to_dict(): assert workflow_dict["environment_variables"][1]["value"] == "text" +def test_normalize_environment_variable_mappings_converts_full_mask_to_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": encrypter.full_mask_token(), + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + +def test_normalize_environment_variable_mappings_keeps_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": HIDDEN_VALUE, + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + class TestWorkflowNodeExecution: def test_execution_metadata_dict(self): node_exec = WorkflowNodeExecutionModel() diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py deleted file mode 100644 index 3707ed90be..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Unit tests for non-SQL helper logic in workflow run repository.""" - -import secrets -from datetime import UTC, datetime -from unittest.mock import Mock, patch - -import pytest - -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType -from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason -from repositories.sqlalchemy_api_workflow_run_repository import ( - _build_human_input_required_reason, - _PrivateWorkflowPauseEntity, -) - - -@pytest.fixture -def sample_workflow_pause() -> Mock: - """Create a sample WorkflowPause model.""" - pause = Mock(spec=WorkflowPauseModel) - pause.id = "pause-123" - pause.workflow_id = "workflow-123" - pause.workflow_run_id = "workflow-run-123" - pause.state_object_key = "workflow-state-123.json" - pause.resumed_at = None - pause.created_at = datetime.now(UTC) - return pause - - -class TestPrivateWorkflowPauseEntity: - """Test _PrivateWorkflowPauseEntity class.""" - - def test_properties(self, sample_workflow_pause: Mock) -> None: - """Test entity properties.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - - # Assert - assert entity.id == sample_workflow_pause.id - assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id - assert entity.resumed_at == sample_workflow_pause.resumed_at - - def test_get_state(self, sample_workflow_pause: Mock) -> None: - """Test getting state from storage.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result = entity.get_state() - - # Assert - assert result == expected_state - mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key) - - def test_get_state_caching(self, sample_workflow_pause: Mock) -> None: - """Test state caching in get_state method.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result1 = entity.get_state() - result2 = entity.get_state() - - # Assert - assert result1 == expected_state - assert result2 == expected_state - mock_storage.load.assert_called_once() - - -class TestBuildHumanInputRequiredReason: - """Test helper that builds HumanInputRequired pause reasons.""" - - def test_prefers_backstage_token_when_available(self) -> None: - """Use backstage token when multiple recipient types may exist.""" - # Arrange - expiration_time = datetime.now(UTC) - form_definition = FormDefinition( - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "Alice"}, - node_title="Ask Name", - display_in_ui=True, - ) - form_model = HumanInputForm( - id="form-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_run_id="run-1", - node_id="node-1", - form_definition=form_definition.model_dump_json(), - rendered_content="rendered", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - reason_model = WorkflowPauseReason( - pause_id="pause-1", - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id="form-1", - node_id="node-1", - message="", - ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - - # Act - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) - - # Assert - assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token - assert reason.node_title == "Ask Name" - assert reason.form_content == "content" - assert reason.inputs[0].output_variable_name == "name" - assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py deleted file mode 100644 index 8daf91c538..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from datetime import UTC, datetime, timedelta - -from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain -from core.entities.execution_extra_content import HumanInputFormSubmissionData -from dify_graph.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from models.execution_extra_content import HumanInputContent as HumanInputContentModel -from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository - - -class _FakeScalarResult: - def __init__(self, values: Sequence[HumanInputContentModel]): - self._values = list(values) - - def all(self) -> list[HumanInputContentModel]: - return list(self._values) - - -class _FakeSession: - def __init__(self, values: Sequence[Sequence[object]]): - self._values = list(values) - - def scalars(self, _stmt): - if not self._values: - return _FakeScalarResult([]) - return _FakeScalarResult(self._values.pop(0)) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -@dataclass -class _FakeSessionMaker: - session: _FakeSession - - def __call__(self) -> _FakeSession: - return self.session - - -def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id=action_id, title=action_title)], - rendered_content="rendered", - expiration_time=expiration_time, - node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id=f"form-{action_id}", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content=rendered_content, - status=HumanInputFormStatus.SUBMITTED, - expiration_time=expiration_time, - ) - form.selected_action_id = action_id - return form - - -def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel: - form = _build_form( - action_id=action_id, - action_title=action_title, - rendered_content=f"Rendered {action_title}", - ) - content = HumanInputContentModel( - id=f"content-{message_id}", - form_id=form.id, - message_id=message_id, - workflow_run_id=form.workflow_run_id, - ) - content.form = form - return content - - -def test_get_by_message_ids_groups_contents_by_message() -> None: - message_ids = ["msg-1", "msg-2"] - contents = [_build_content("msg-1", "approve", "Approve")] - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []])) - ) - - result = repository.get_by_message_ids(message_ids) - - assert len(result) == 2 - assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [ - HumanInputContentDomain( - workflow_run_id="workflow-run", - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id="node-id", - node_title="Approval", - rendered_content="Rendered Approve", - action_id="approve", - action_text="Approve", - ), - ).model_dump(mode="json", exclude_none=True) - ] - assert result[1] == [] - - -def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "John"}, - node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id="form-1", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content="Rendered block", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - content = HumanInputContentModel( - id="content-msg-1", - form_id=form.id, - message_id="msg-1", - workflow_run_id=form.workflow_run_id, - ) - content.form = form - - recipient = HumanInputFormRecipient( - form_id=form.id, - delivery_id="delivery-1", - recipient_type=RecipientType.CONSOLE, - recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(), - access_token="token-1", - ) - - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]])) - ) - - result = repository.get_by_message_ids(["msg-1"]) - - assert len(result) == 1 - assert len(result[0]) == 1 - domain_content = result[0][0] - assert domain_content.submitted is False - assert domain_content.workflow_run_id == "workflow-run" - assert domain_content.form_definition is not None - assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp()) - assert domain_content.form_definition is not None - form_definition = domain_content.form_definition - assert form_definition.form_id == "form-1" - assert form_definition.node_id == "node-id" - assert form_definition.node_title == "Approval" - assert form_definition.form_content == "Rendered block" - assert form_definition.display_in_ui is True - assert form_definition.form_token == "token-1" - assert form_definition.resolved_default_values == {"name": "John"} - assert form_definition.expiration_time == int(form.expiration_time.timestamp()) diff --git a/api/tests/unit_tests/repositories/test_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_workflow_run_repository.py deleted file mode 100644 index 8f47f0df48..0000000000 --- a/api/tests/unit_tests/repositories/test_workflow_run_repository.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Unit tests for workflow run repository with status filter.""" - -import uuid -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import sessionmaker - -from models import WorkflowRun, WorkflowRunTriggeredFrom -from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository - - -class TestDifyAPISQLAlchemyWorkflowRunRepository: - """Test workflow run repository with status filtering.""" - - @pytest.fixture - def mock_session_maker(self): - """Create a mock session maker.""" - return MagicMock(spec=sessionmaker) - - @pytest.fixture - def repository(self, mock_session_maker): - """Create repository instance with mock session.""" - return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) - - def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker): - """Test getting paginated workflow runs without status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)] - mock_session.scalars.return_value.all.return_value = mock_runs - - # Act - result = repository.get_paginated_workflow_runs( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - limit=20, - last_id=None, - status=None, - ) - - # Assert - assert len(result.data) == 3 - assert result.limit == 20 - assert result.has_more is False - - def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker): - """Test getting paginated workflow runs with status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)] - mock_session.scalars.return_value.all.return_value = mock_runs - - # Act - result = repository.get_paginated_workflow_runs( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - limit=20, - last_id=None, - status="succeeded", - ) - - # Assert - assert len(result.data) == 2 - assert all(run.status == "succeeded" for run in result.data) - - def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker): - """Test getting workflow runs count without status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the GROUP BY query results - mock_results = [ - ("succeeded", 5), - ("failed", 2), - ("running", 1), - ] - mock_session.execute.return_value.all.return_value = mock_results - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status=None, - ) - - # Assert - assert result["total"] == 8 - assert result["succeeded"] == 5 - assert result["failed"] == 2 - assert result["running"] == 1 - assert result["stopped"] == 0 - assert result["partial-succeeded"] == 0 - - def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker): - """Test getting workflow runs count with status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the count query for succeeded status - mock_session.scalar.return_value = 5 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="succeeded", - ) - - # Assert - assert result["total"] == 5 - assert result["succeeded"] == 5 - assert result["running"] == 0 - assert result["failed"] == 0 - assert result["stopped"] == 0 - assert result["partial-succeeded"] == 0 - - def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker): - """Test that invalid status is still counted in total but not in any specific status.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock count query returning 0 for invalid status - mock_session.scalar.return_value = 0 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="invalid_status", - ) - - # Assert - assert result["total"] == 0 - assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"]) - - def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker): - """Test getting workflow runs count with time range filter verifies SQL query construction.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the GROUP BY query results - mock_results = [ - ("succeeded", 3), - ("running", 2), - ] - mock_session.execute.return_value.all.return_value = mock_results - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status=None, - time_range="1d", - ) - - # Assert results - assert result["total"] == 5 - assert result["succeeded"] == 3 - assert result["running"] == 2 - assert result["failed"] == 0 - - # Verify that execute was called (which means GROUP BY query was used) - assert mock_session.execute.called, "execute should have been called for GROUP BY query" - - # Verify SQL query includes time filter by checking the statement - call_args = mock_session.execute.call_args - assert call_args is not None, "execute should have been called with a statement" - - # The first argument should be the SQL statement - stmt = call_args[0][0] - # Convert to string to inspect the query - query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) - - # Verify the query includes created_at filter - # The query should have a WHERE clause with created_at comparison - assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( - "Query should include created_at filter for time range" - ) - - def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker): - """Test getting workflow runs count with both status and time range filters verifies SQL query.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the count query for running status within time range - mock_session.scalar.return_value = 2 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="running", - time_range="1d", - ) - - # Assert results - assert result["total"] == 2 - assert result["running"] == 2 - assert result["succeeded"] == 0 - assert result["failed"] == 0 - - # Verify that scalar was called (which means COUNT query was used) - assert mock_session.scalar.called, "scalar should have been called for count query" - - # Verify SQL query includes both status and time filter - call_args = mock_session.scalar.call_args - assert call_args is not None, "scalar should have been called with a statement" - - # The first argument should be the SQL statement - stmt = call_args[0][0] - # Convert to string to inspect the query - query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) - - # Verify the query includes both filters - assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( - "Query should include created_at filter for time range" - ) - assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), ( - "Query should include status filter" - ) diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py index b5d91ef3fb..388504c07f 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py @@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase): class TestApiKeyAuthBase: def test_should_store_credentials_on_init(self): """Test that credentials are properly stored during initialization""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) assert auth.credentials == credentials def test_should_not_instantiate_abstract_class(self): """Test that ApiKeyAuthBase cannot be instantiated directly""" - credentials = {"api_key": "test_key"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} with pytest.raises(TypeError) as exc_info: ApiKeyAuthBase(credentials) @@ -29,7 +29,7 @@ class TestApiKeyAuthBase: def test_should_allow_subclass_implementation(self): """Test that subclasses can properly implement the abstract method""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) # Should not raise any exception diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py index 60af6e20c2..b1f7cf24f3 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -58,7 +58,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) result = factory.validate_credentials() # Assert @@ -75,7 +75,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act & Assert - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) with pytest.raises(Exception) as exc_info: factory.validate_credentials() assert str(exc_info.value) == "Authentication error" diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e2..424ac18870 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory: name: str = "Test Dataset", description: str = "Test description", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str | None = "openai", embedding_model: str | None = "text-embedding-ada-002", collection_binding_id: str | None = "binding-123", @@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_knowledge_configuration_mock( chunk_structure: str = "tree", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", keyword_number: int = 10, @@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: # Assert assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == "binding-123" @@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", # Existing structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", # Different structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) mock_session.merge.return_value = dataset @@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( dataset_id="dataset-123", runtime_mode="rag_pipeline", - indexing_technique="high_quality", # Current technique + indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique="economy", # Trying to change to economy + indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy ) mock_session.merge.return_value = dataset diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6829691507..49fdc5cc9b 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,6 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService @@ -153,7 +154,7 @@ class DocumentValidationTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", **kwargs, @@ -188,8 +189,8 @@ class DocumentValidationTestDataFactory: def create_knowledge_config_mock( data_source: DataSource | None = None, process_rule: ProcessRule | None = None, - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, **kwargs, ) -> Mock: """ @@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm: - Validation logic works correctly """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "text_model" + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None) - doc_form = "text_model" + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm: - Error type is correct """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "table_model" # Different form + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card") - doc_form = "text_model" # Different form + doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -480,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting: - No errors are raised """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) # Act (should not raise) DatasetService.check_dataset_model_setting(dataset) @@ -502,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="invalid-model", ) @@ -532,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py index 27df4556bc..6511385000 100644 --- a/api/tests/unit_tests/services/plugin/test_oauth_service.py +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -13,6 +13,10 @@ import pytest from services.plugin.oauth_service import OAuthProxyService +def _oauth_proxy_setex_calls(redis_client) -> list: + return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")] + + class TestCreateProxyContext: def test_stores_context_in_redis_with_ttl(self): context_id = OAuthProxyService.create_proxy_context( @@ -22,8 +26,9 @@ class TestCreateProxyContext: assert context_id # non-empty UUID string from extensions.ext_redis import redis_client - redis_client.setex.assert_called_once() - call_args = redis_client.setex.call_args + oauth_calls = _oauth_proxy_setex_calls(redis_client) + assert len(oauth_calls) == 1 + call_args = oauth_calls[0] key = call_args[0][0] ttl = call_args[0][1] stored_data = json.loads(call_args[0][2]) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py deleted file mode 100644 index 9fe153c153..0000000000 --- a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,216 +0,0 @@ -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy.orm import Session - -from models.workflow import WorkflowRun -from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult - - -class TestArchivedWorkflowRunDeletion: - @pytest.fixture - def mock_db(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - @pytest.fixture - def mock_sessionmaker(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - mock_session = MagicMock(spec=Session) - mock_sm.return_value.return_value.__enter__.return_value = mock_session - yield mock_sm, mock_session - - @pytest.fixture - def mock_workflow_run_repo(self): - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - yield mock_repo - - def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - tenant_id = "tenant-456" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_run.tenant_id = tenant_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [run_id] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) - mock_delete_run.return_value = expected_result - - result = deletion.delete_by_run_id(run_id) - - assert result == expected_result - mock_session.get.assert_called_once_with(WorkflowRun, run_id) - mock_repo.get_archived_run_ids.assert_called_once() - mock_delete_run.assert_called_once_with(mock_run) - - def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - mock_session.get.return_value = None - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo"): - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "not found" in result.error - assert result.run_id == run_id - - def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [] - - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "is not archived" in result.error - - def test_delete_batch(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - deletion = ArchivedWorkflowRunDeletion() - - mock_run1 = MagicMock(spec=WorkflowRun) - mock_run1.id = "run-1" - mock_run2 = MagicMock(spec=WorkflowRun) - mock_run2.id = "run-2" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - mock_delete_run.side_effect = [ - DeleteResult(run_id="run-1", tenant_id="t1", success=True), - DeleteResult(run_id="run-2", tenant_id="t1", success=True), - ] - - results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) - - assert len(results) == 2 - assert results[0].run_id == "run-1" - assert results[1].run_id == "run-2" - assert mock_delete_run.call_count == 2 - - def test_delete_run_dry_run(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=True) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.run_id == "run-123" - - def test_delete_run_success(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.deleted_counts == {"workflow_runs": 1} - - def test_delete_run_exception(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.side_effect = Exception("Database error") - - result = deletion._delete_run(mock_run) - - assert result.success is False - assert result.error == "Database error" - - def test_delete_trigger_logs(self): - mock_session = MagicMock(spec=Session) - run_ids = ["run-1", "run-2"] - - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - mock_repo_cls.return_value = mock_repo - mock_repo.delete_by_run_ids.return_value = 5 - - count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) - - assert count == 5 - mock_repo_cls.assert_called_once_with(mock_session) - mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) - - def test_delete_node_executions(self): - mock_session = MagicMock(spec=Session) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-1" - runs = [mock_run] - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - mock_repo.delete_by_runs.return_value = (1, 2) - - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) - - assert result == (1, 2) - mock_create_repo.assert_called_once() - mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) - - def test_get_workflow_run_repo(self, mock_db): - deletion = ArchivedWorkflowRunDeletion() - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - - # First call - repo1 = deletion._get_workflow_run_repo() - assert repo1 == mock_repo - assert deletion.workflow_run_repo == mock_repo - - # Second call (should return cached) - repo2 = deletion._get_workflow_run_repo() - assert repo2 == mock_repo - mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..14af7f7119 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,8 +2,10 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +79,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None @@ -90,7 +92,7 @@ class SegmentTestDataFactory: document_id: str = "doc-123", dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, word_count: int = 100, **kwargs, ) -> Mock: @@ -109,7 +111,7 @@ class SegmentTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model: str = "text-embedding-ada-002", embedding_model_provider: str = "openai", **kwargs, @@ -161,7 +163,7 @@ class TestSegmentServiceCreateSegment: """Test successful creation of a segment.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test", "segment"]} mock_query = MagicMock() @@ -209,8 +211,8 @@ class TestSegmentServiceCreateSegment: def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): """Test creation of segment with QA model (requires answer).""" # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} mock_query = MagicMock() @@ -245,7 +247,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with high quality indexing technique.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -287,7 +289,7 @@ class TestSegmentServiceCreateSegment: """Test segment creation when vector indexing fails.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -340,7 +342,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment @@ -428,8 +430,8 @@ class TestSegmentServiceUpdateSegment: """Test update segment with QA model (includes answer).""" # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py deleted file mode 100644 index a6bc79e82b..0000000000 --- a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Unit tests for services.advanced_prompt_template_service -""" - -import copy - -from core.prompt.prompt_templates.advanced_prompt_templates import ( - BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, - CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - CONTEXT, -) -from models.model import AppMode -from services.advanced_prompt_template_service import AdvancedPromptTemplateService - - -class TestAdvancedPromptTemplateService: - """Test suite for AdvancedPromptTemplateService.""" - - def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: - """Test baichuan model names use baichuan context prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "chat", - "model_name": "Baichuan2-13B", - "has_context": "true", - } - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - - def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: - """Test non-baichuan model names use common prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "completion", - "model_name": "gpt-4", - "has_context": "false", - } - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result == original_config - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: - """Test invalid app mode returns empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "chat" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} - - def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: - """Test context is prepended for completion prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: - """Test context is prepended for chat prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) - assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: - """Test chat prompt remains unchanged when has_context is false.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") - - # Assert - assert result == original_config - assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: - """Test completion app mode with completion model returns completion prompt.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") - - # Assert - assert result == original_config - assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: - """Test invalid model mode returns empty dict.""" - # Arrange - app_mode = AppMode.CHAT - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") - - # Assert - assert result == {} - - def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps completion prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] - - # Act - result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"] == original_text - - def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps chat prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] - - # Act - result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text - - def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: - """Test baichuan chat/completion returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: - """Test baichuan completion/chat returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: - """Test baichuan completion/completion prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: - """Test baichuan chat/chat prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: - """Test invalid baichuan mode combinations return empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py deleted file mode 100644 index 7ce3d7ef7b..0000000000 --- a/api/tests/unit_tests/services/test_agent_service.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -Unit tests for services.agent_service -""" - -from collections.abc import Callable -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -import pytz - -from core.plugin.impl.exc import PluginDaemonClientSideError -from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought -from services.agent_service import AgentService - - -def _make_current_user_account(timezone: str = "UTC") -> Account: - account = Account(name="Test User", email="test@example.com") - account.timezone = timezone - return account - - -def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: - app_model = MagicMock(spec=App) - app_model.id = "app-123" - app_model.tenant_id = "tenant-123" - app_model.app_model_config = app_model_config - return app_model - - -def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: - conversation = MagicMock(spec=Conversation) - conversation.id = "conv-123" - conversation.app_id = "app-123" - conversation.from_end_user_id = from_end_user_id - conversation.from_account_id = from_account_id - return conversation - - -def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: - message = MagicMock(spec=Message) - message.id = "msg-123" - message.conversation_id = "conv-123" - message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - message.provider_response_latency = 1.23 - message.answer_tokens = 4 - message.message_tokens = 6 - message.agent_thoughts = agent_thoughts - message.message_files = ["file-a.txt"] - return message - - -def _make_agent_thought() -> MagicMock: - agent_thought = MagicMock(spec=MessageAgentThought) - agent_thought.tokens = 3 - agent_thought.tool_input = "raw-input" - agent_thought.observation = "raw-output" - agent_thought.thought = "thinking" - agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - agent_thought.files = [] - agent_thought.tools = ["tool_a", "dataset_tool"] - agent_thought.tool_labels = {"tool_a": "Tool A"} - agent_thought.tool_meta = { - "tool_a": { - "tool_config": { - "tool_provider_type": "custom", - "tool_provider": "provider-1", - }, - "tool_parameters": {"param": "value"}, - "time_cost": 2.5, - }, - "dataset_tool": { - "tool_config": { - "tool_provider_type": "dataset-retrieval", - "tool_provider": "dataset-provider", - } - }, - } - agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} - agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} - return agent_thought - - -def _build_query_side_effect( - conversation: Conversation | None, - message: Message | None, - executor: EndUser | Account | None, -) -> Callable[..., MagicMock]: - def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: - query = MagicMock() - query.where.return_value = query - if any(arg is Conversation for arg in args): - query.first.return_value = conversation - elif any(arg is Message for arg in args): - query.first.return_value = message - elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): - query.first.return_value = executor - return query - - return _query_side_effect - - -class TestAgentServiceGetAgentLogs: - """Test suite for AgentService.get_agent_logs.""" - - def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: - """Test missing conversation raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - with patch("services.agent_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") - - def test_get_agent_logs_should_raise_when_message_missing(self) -> None: - """Test missing message raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - with patch("services.agent_service.db") as mock_db: - conversation_query = MagicMock() - conversation_query.where.return_value = conversation_query - conversation_query.first.return_value = conversation - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = None - - mock_db.session.query.side_effect = [conversation_query, message_query] - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") - - def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: - """Test missing app model config raises ValueError.""" - # Arrange - app_model = _make_app_model(None) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: - """Test missing agent config raises ValueError.""" - # Arrange - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: - """Test agent logs returned for end-user executor with tool icons.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - executor = MagicMock(spec=EndUser) - executor.name = "End User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_tool = MagicMock() - agent_tool.tool_name = "tool_a" - agent_tool.provider_type = "custom" - agent_tool.provider_id = "provider-2" - agent_config = MagicMock() - agent_config.tools = [agent_tool] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, - patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - mock_get_icon.side_effect = [None, "icon-a"] - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["status"] == "success" - assert result["meta"]["executor"] == "End User" - assert result["meta"]["total_tokens"] == 10 - assert result["meta"]["agent_mode"] == "react" - assert result["meta"]["iterations"] == 1 - assert result["files"] == ["file-a.txt"] - assert len(result["iterations"]) == 1 - tool_calls = result["iterations"][0]["tool_calls"] - assert tool_calls[0]["tool_name"] == "tool_a" - assert tool_calls[0]["tool_icon"] == "icon-a" - assert tool_calls[1]["tool_name"] == "dataset_tool" - assert tool_calls[1]["tool_icon"] == "" - mock_convert.assert_called_once() - - def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: - """Test agent logs fall back to account executor when end user is missing.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") - executor = MagicMock(spec=Account) - executor.name = "Account User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Account User" - - def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: - """Test unknown executor and missing tool details fall back to defaults.""" - # Arrange - agent_thought = _make_agent_thought() - agent_thought.tool_labels = {} - agent_thought.tool_inputs_dict = {} - agent_thought.tool_outputs_dict = None - agent_thought.tool_meta = {"tool_a": {"error": "failed"}} - agent_thought.tools = ["tool_a"] - - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Unknown" - assert result["meta"]["agent_mode"] == "react" - tool_call = result["iterations"][0]["tool_calls"][0] - assert tool_call["status"] == "error" - assert tool_call["error"] == "failed" - assert tool_call["tool_label"] == "tool_a" - assert tool_call["tool_input"] == {} - assert tool_call["tool_output"] == {} - assert tool_call["time_cost"] == 0 - assert tool_call["tool_parameters"] == {} - assert tool_call["tool_icon"] is None - - -class TestAgentServiceProviders: - """Test suite for AgentService provider methods.""" - - def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: - """Test list_agent_providers delegates to PluginAgentClient.""" - # Arrange - tenant_id = "tenant-1" - expected = [{"name": "provider"}] - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_providers.return_value = expected - - # Act - result = AgentService.list_agent_providers("user-1", tenant_id) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) - - def test_get_agent_provider_should_return_provider_when_successful(self) -> None: - """Test get_agent_provider returns provider when successful.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - expected = {"name": provider_name} - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.return_value = expected - - # Act - result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) - - def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: - """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( - "plugin error" - ) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py deleted file mode 100644 index 7f4b5fdaa3..0000000000 --- a/api/tests/unit_tests/services/test_api_based_extension_service.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Comprehensive unit tests for services/api_based_extension_service.py - -Covers: -- APIBasedExtensionService.get_all_by_tenant_id -- APIBasedExtensionService.save -- APIBasedExtensionService.delete -- APIBasedExtensionService.get_with_tenant_id -- APIBasedExtensionService._validation (new record & existing record branches) -- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from services.api_based_extension_service import APIBasedExtensionService - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_extension( - *, - id_: str | None = None, - tenant_id: str = "tenant-001", - name: str = "my-ext", - api_endpoint: str = "https://example.com/hook", - api_key: str = "secret-key-123", -) -> MagicMock: - """Return a lightweight mock that mimics APIBasedExtension.""" - ext = MagicMock() - ext.id = id_ - ext.tenant_id = tenant_id - ext.name = name - ext.api_endpoint = api_endpoint - ext.api_key = api_key - return ext - - -# --------------------------------------------------------------------------- -# Tests: get_all_by_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetAllByTenantId: - """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): - """Each api_key is decrypted and the list is returned.""" - ext1 = _make_extension(id_="id-1", api_key="enc-key-1") - ext2 = _make_extension(id_="id-2", api_key="enc-key-2") - - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ - ext1, - ext2, - ] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [ext1, ext2] - assert ext1.api_key == "decrypted-key" - assert ext2.api_key == "decrypted-key" - assert mock_decrypt.call_count == 2 - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): - """Returns an empty list gracefully when no records exist.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [] - mock_decrypt.assert_not_called() - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): - """Verifies the DB is queried with the supplied tenant_id.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") - - mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") - - -# --------------------------------------------------------------------------- -# Tests: save -# --------------------------------------------------------------------------- - - -class TestSave: - """Tests for APIBasedExtensionService.save.""" - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation") - def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): - """Happy path: validation passes, key is encrypted, record is added and committed.""" - ext = _make_extension(id_=None, api_key="plain-key-123") - - result = APIBasedExtensionService.save(ext) - - mock_validation.assert_called_once_with(ext) - mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") - assert ext.api_key == "encrypted-key" - mock_db.session.add.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - assert result is ext - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) - def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): - """If _validation raises, save should propagate the error without touching the DB.""" - ext = _make_extension(name="") - - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(ext) - - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: delete -# --------------------------------------------------------------------------- - - -class TestDelete: - """Tests for APIBasedExtensionService.delete.""" - - @patch("services.api_based_extension_service.db") - def test_delete_removes_record_and_commits(self, mock_db): - """delete() must call session.delete with the extension and then commit.""" - ext = _make_extension(id_="delete-me") - - APIBasedExtensionService.delete(ext) - - mock_db.session.delete.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests: get_with_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetWithTenantId: - """Tests for APIBasedExtensionService.get_with_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): - """Found extension has its api_key decrypted before being returned.""" - ext = _make_extension(id_="ext-123", api_key="enc-key") - - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext - - result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") - - assert result is ext - assert ext.api_key == "decrypted-key" - mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") - - @patch("services.api_based_extension_service.db") - def test_raises_value_error_when_not_found(self, mock_db): - """Raises ValueError when no matching extension exists.""" - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None - - with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): - """Verifies both tenant_id and extension id are used in the query.""" - ext = _make_extension(id_="ext-abc") - chain = mock_db.session.query.return_value - chain.filter_by.return_value.filter_by.return_value.first.return_value = ext - - APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") - - # First filter_by call uses tenant_id - chain.filter_by.assert_called_once_with(tenant_id="tenant-002") - # Second filter_by call uses id - chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") - - -# --------------------------------------------------------------------------- -# Tests: _validation (new record — id is falsy) -# --------------------------------------------------------------------------- - - -class TestValidationNewRecord: - """Tests for _validation() with a brand-new record (no id).""" - - def _build_mock_db(self, name_exists: bool = False): - mock_db = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() if name_exists else None - ) - return mock_db - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_new_extension_passes(self, mock_db, mock_ping): - """A new record with all valid fields should pass without exceptions.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_empty(self, mock_db): - """Empty name raises ValueError.""" - ext = _make_extension(id_=None, name="") - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_none(self, mock_db): - """None name raises ValueError.""" - ext = _make_extension(id_=None, name=None) - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_already_exists_for_new_record(self, mock_db): - """A new record whose name already exists raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() - ) - ext = _make_extension(id_=None, name="duplicate-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_empty(self, mock_db): - """Empty api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint="") - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_none(self, mock_db): - """None api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint=None) - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_empty(self, mock_db): - """Empty api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="") - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_none(self, mock_db): - """None api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key=None) - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_too_short(self, mock_db): - """api_key shorter than 5 characters raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="abc") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_exactly_four_chars(self, mock_db): - """api_key with exactly 4 characters raises ValueError (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="1234") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): - """api_key with exactly 5 characters should pass (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="12345") - - # Should not raise - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _validation (existing record — id is truthy) -# --------------------------------------------------------------------------- - - -class TestValidationExistingRecord: - """Tests for _validation() with an existing record (id is set).""" - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_existing_extension_passes(self, mock_db, mock_ping): - """An existing record whose name is unique (excluding self) should pass.""" - # .where(...).first() → None means no *other* record has that name - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = None - ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): - """Existing record cannot use a name already owned by a different record.""" - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = MagicMock() - ext = _make_extension(id_="existing-id", name="taken-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _ping_connection -# --------------------------------------------------------------------------- - - -class TestPingConnection: - """Tests for APIBasedExtensionService._ping_connection.""" - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_successful_ping_returns_pong(self, mock_requestor_class): - """When the endpoint returns {"result": "pong"}, no exception is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") - # Should not raise - APIBasedExtensionService._ping_connection(ext) - - mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): - """When the response is not {"result": "pong"}, a ValueError is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "error"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_network_exception_wraps_in_value_error(self, mock_requestor_class): - """Any exception raised during request is wrapped in a ValueError.""" - mock_client = MagicMock() - mock_client.request.side_effect = ConnectionError("network failure") - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): - """Exception raised by the requestor constructor itself is wrapped.""" - mock_requestor_class.side_effect = RuntimeError("bad config") - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_missing_result_key_raises_value_error(self, mock_requestor_class): - """A response dict without a 'result' key does not equal 'pong' → raises.""" - mock_client = MagicMock() - mock_client.request.return_value = {} # no 'result' key - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_uses_ping_extension_point(self, mock_requestor_class): - """The PING extension point is passed to the client.request call.""" - from models.api_based_extension import APIBasedExtensionPoint - - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - APIBasedExtensionService._ping_connection(ext) - - call_kwargs = mock_client.request.call_args - assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING - assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 4f7d184046..239e51119c 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch def test_import_app_pending_stores_import_info_in_redis(): service = AppDslService(MagicMock()) + app_dsl_service.redis_client.setex.reset_mock() result = service.import_app( account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, @@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch): created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + app_dsl_service.redis_client.delete.reset_mock() result = service.confirm_import(import_id="import-1", account=_account_mock()) assert result.status == ImportStatus.COMPLETED assert result.app_id == "confirmed-app" - app_dsl_service.redis_client.delete.assert_called_once() + app_dsl_service.redis_client.delete.assert_called_once_with( + f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1" + ) def test_confirm_import_exception_returns_failed(monkeypatch): diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py deleted file mode 100644 index bff8dc92c6..0000000000 --- a/api/tests/unit_tests/services/test_app_service.py +++ /dev/null @@ -1,609 +0,0 @@ -"""Unit tests for services.app_service.""" - -import json -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock, patch - -import pytest - -from core.errors.error import ProviderTokenNotInitError -from models import Account, Tenant -from models.model import App, AppMode -from services.app_service import AppService - - -@pytest.fixture -def service() -> AppService: - """Provide AppService instance.""" - return AppService() - - -@pytest.fixture -def account() -> Account: - """Create account object for create_app tests.""" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - result = Account(name="Account User", email="account@example.com") - result.id = "acc-1" - result._current_tenant = tenant - return result - - -@pytest.fixture -def default_args() -> dict: - """Create default create_app args.""" - return { - "name": "Test App", - "mode": AppMode.CHAT.value, - "icon": "🤖", - "icon_background": "#FFFFFF", - } - - -@pytest.fixture -def app_template() -> dict: - """Create basic app template for create_app tests.""" - return { - AppMode.CHAT: { - "app": {}, - "model_config": { - "model": { - "provider": "provider-a", - "name": "model-a", - "mode": "chat", - "completion_params": {}, - } - }, - } - } - - -def _make_current_user() -> Account: - user = Account(name="Tester", email="tester@example.com") - user.id = "user-1" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - user._current_tenant = tenant - return user - - -class TestAppServicePagination: - """Test suite for get_paginate_apps.""" - - def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: - """Test pagination returns None when tag filter has no targets.""" - # Arrange - args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} - - with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is None - - def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: - """Test pagination delegates to db.paginate when filters are valid.""" - # Arrange - args = { - "mode": "workflow", - "is_created_by_me": True, - "name": "My_App%", - "tag_ids": ["tag-1"], - "page": 2, - "limit": 10, - } - expected_pagination = MagicMock() - - with ( - patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), - patch("libs.helper.escape_like_pattern", return_value="escaped"), - patch("services.app_service.db") as mock_db, - ): - mock_db.paginate.return_value = expected_pagination - - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is expected_pagination - mock_db.paginate.assert_called_once() - - -class TestAppServiceCreate: - """Test suite for create_app.""" - - def test_create_app_should_create_with_matching_default_model( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app uses matching default model and persists app config.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - model_instance = SimpleNamespace( - model_name="model-a", - provider="provider-a", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) - mock_config.BILLING_ENABLED = True - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - assert app_instance.app_model_config_id == "cfg-1" - mock_db.session.add.assert_any_call(app_instance) - mock_db.session.add.assert_any_call(app_model_config) - assert mock_db.session.flush.call_count == 2 - mock_db.session.commit.assert_called_once() - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - - def test_create_app_should_raise_when_model_schema_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app raises ValueError when non-matching model has no schema.""" - # Arrange - app_instance = SimpleNamespace(id="app-1") - model_instance = SimpleNamespace( - model_name="model-b", - provider="provider-b", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - model_instance.model_type_instance.get_model_schema.return_value = None - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - - # Act & Assert - with pytest.raises(ValueError, match="model schema not found"): - service.create_app("tenant-1", default_args, account) - mock_db.session.commit.assert_not_called() - - def test_create_app_should_fallback_to_default_provider_when_model_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app falls back to provider/model name when no default model instance is available.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) - mock_config.BILLING_ENABLED = False - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") - - def test_create_app_should_log_and_fallback_on_unexpected_model_error( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test unexpected model manager errors are logged and fallback provider is used.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db"), - patch("services.app_service.app_was_created"), - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), - ), - patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), - patch("services.app_service.logger") as mock_logger, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = RuntimeError("boom") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_logger.exception.assert_called_once() - - -class TestAppServiceGetAndUpdate: - """Test suite for app retrieval and update methods.""" - - def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: - """Test get_app returns original app for non-agent modes.""" - # Arrange - app = MagicMock() - app.mode = AppMode.CHAT - app.is_agent = False - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: - """Test get_app returns app when agent mode has no model config.""" - # Arrange - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = None - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: - """Test get_app decrypts and masks secret tool parameters.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - manager = MagicMock() - manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} - manager.mask_tool_parameters.return_value = {"secret": "***"} - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), - patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - assert tool["tool_parameters"] == {"secret": "***"} - assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} - - def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: - """Test get_app logs and continues when masking fails.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), - patch("services.app_service.logger") as mock_logger, - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - mock_logger.exception.assert_called_once() - - def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: - """Test update methods set fields and commit changes.""" - # Arrange - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type="emoji", - icon="a", - icon_background="#111", - enable_site=True, - enable_api=True, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": "image", - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - patch("services.app_service.naive_utc_now", return_value="now"), - ): - # Act - updated = service.update_app(app, args) - renamed = service.update_app_name(app, "rename") - iconed = service.update_app_icon(app, "icon-2", "#333") - site_same = service.update_app_site_status(app, app.enable_site) - api_same = service.update_app_api_status(app, app.enable_api) - site_changed = service.update_app_site_status(app, False) - api_changed = service.update_app_api_status(app, False) - - # Assert - assert updated is app - assert renamed is app - assert iconed is app - assert site_same is app - assert api_same is app - assert site_changed is app - assert api_changed is app - assert mock_db.session.commit.call_count >= 5 - - -class TestAppServiceDeleteAndMeta: - """Test suite for delete and metadata methods.""" - - def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: - """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" - # Arrange - app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - with ( - patch("services.app_service.db") as mock_db, - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), - ), - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch( - "services.app_service.dify_config", - new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), - ), - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.remove_app_and_related_data_task") as mock_task, - ): - # Act - service.delete_app(app) - - # Assert - mock_db.session.delete.assert_called_once_with(app) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") - - def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: - """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" - # Arrange - workflow = SimpleNamespace( - graph_dict={ - "nodes": [ - { - "data": { - "type": "tool", - "provider_type": "builtin", - "provider_id": "builtin-provider", - "tool_name": "tool_builtin", - } - }, - { - "data": { - "type": "tool", - "provider_type": "api", - "provider_id": "api-provider-id", - "tool_name": "tool_api", - } - }, - ] - } - ) - app = cast( - App, - SimpleNamespace( - mode=AppMode.WORKFLOW.value, - workflow=workflow, - app_model_config=None, - tenant_id="tenant-1", - icon_type="emoji", - icon_background="#fff", - ), - ) - - provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = provider - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") - assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} - - def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: - """Test get_app_meta falls back to default icon when API provider lookup fails.""" - # Arrange - app_model_config = SimpleNamespace( - agent_mode_dict={ - "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] - } - ) - app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} - - def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: - """Test get_app_meta returns empty metadata when workflow/model config is absent.""" - # Arrange - workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) - chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) - - # Act - workflow_meta = service.get_app_meta(workflow_app) - chat_meta = service.get_app_meta(chat_app) - - # Assert - assert workflow_meta == {"tool_icons": {}} - assert chat_meta == {"tool_icons": {}} - - -class TestAppServiceCodeLookup: - """Test suite for app code lookup methods.""" - - def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: - """Test get_app_code_by_id raises when site is missing.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_code_by_id("app-1") - - def test_get_app_code_by_id_should_return_code(self) -> None: - """Test get_app_code_by_id returns site code.""" - # Arrange - site = SimpleNamespace(code="code-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_code_by_id("app-1") - - # Assert - assert result == "code-1" - - def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: - """Test get_app_id_by_code raises when code does not exist.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_id_by_code("missing") - - def test_get_app_id_by_code_should_return_app_id(self) -> None: - """Test get_app_id_by_code returns linked app id.""" - # Arrange - site = SimpleNamespace(app_id="app-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_id_by_code("code-1") - - # Assert - assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_attachment_service.py b/api/tests/unit_tests/services/test_attachment_service.py deleted file mode 100644 index 88be20bc41..0000000000 --- a/api/tests/unit_tests/services/test_attachment_service.py +++ /dev/null @@ -1,73 +0,0 @@ -import base64 -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from werkzeug.exceptions import NotFound - -import services.attachment_service as attachment_service_module -from models.model import UploadFile -from services.attachment_service import AttachmentService - - -class TestAttachmentService: - def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self): - """Test that AttachmentService keeps the provided sessionmaker instance.""" - session_factory = sessionmaker() - - service = AttachmentService(session_factory=session_factory) - - assert service._session_maker is session_factory - - def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self): - """Test that AttachmentService builds a sessionmaker bound to the provided engine.""" - engine = create_engine("sqlite:///:memory:") - - service = AttachmentService(session_factory=engine) - session = service._session_maker() - try: - assert session.bind == engine - finally: - session.close() - engine.dispose() - - @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) - def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory): - """Test that invalid session_factory types are rejected.""" - with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): - AttachmentService(session_factory=invalid_session_factory) - - def test_should_return_base64_encoded_blob_when_file_exists(self): - """Test that existing files are loaded from storage and returned as base64.""" - service = AttachmentService(session_factory=sessionmaker()) - upload_file = MagicMock(spec=UploadFile) - upload_file.key = "upload-file-key" - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = upload_file - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: - result = service.get_file_base64("file-123") - - assert result == base64.b64encode(b"binary-content").decode() - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_called_once_with("upload-file-key") - - def test_should_raise_not_found_when_file_does_not_exist(self): - """Test that missing files raise NotFound and never call storage.""" - service = AttachmentService(session_factory=sessionmaker()) - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once") as mock_load: - with pytest.raises(NotFound, match="File not found"): - service.get_file_base64("missing-file") - - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_not_called() diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d67469105..35b288cf7c 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -405,7 +405,7 @@ class TestAudioServiceTTS: voice="en-US-Neural", ) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" @@ -549,7 +549,7 @@ class TestAudioServiceTTS: with pytest.raises(ValueError, match="Text is required"): AudioService.transcript_tts(app_model=app, text=None) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): """Test that TTS returns None for invalid message ID format.""" # Arrange @@ -564,7 +564,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): """Test that TTS returns None when message doesn't exist.""" # Arrange @@ -585,7 +585,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): """Test that TTS returns None when message answer is empty.""" # Arrange diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index eecb3c7672..316381f0ca 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request): + """Test bulk plan retrieval converts string expiration_date to int.""" + # Arrange + tenant_ids = ["tenant-1"] + mock_send_request.return_value = { + "data": { + "tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"}, + } + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert "tenant-1" in result + assert isinstance(result["tenant-1"]["expiration_date"], int) + assert result["tenant-1"]["expiration_date"] == 1735689600 + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" # Arrange diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 75551531a2..35157790ca 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -15,6 +15,7 @@ from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -350,7 +351,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_account_id=user.id, from_source="console" + from_account_id=user.id, from_source=ConversationFromSource.CONSOLE ) mock_query = mock_db_session.query.return_value @@ -374,7 +375,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_end_user_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_end_user_id=user.id, from_source="api" + from_end_user_id=user.id, from_source=ConversationFromSource.API ) mock_query = mock_db_session.query.return_value @@ -1111,7 +1112,7 @@ class TestConversationServiceEdgeCases: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source="api", from_end_user_id="user-123" + from_source=ConversationFromSource.API, from_end_user_id="user-123" ) mock_session.scalars.return_value.all.return_value = [conversation] @@ -1143,7 +1144,7 @@ class TestConversationServiceEdgeCases: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session conversation = ConversationServiceTestDataFactory.create_conversation_mock( - from_source="console", from_account_id="account-123" + from_source=ConversationFromSource.CONSOLE, from_account_id="account-123" ) mock_session.scalars.return_value.all.return_value = [conversation] diff --git a/api/tests/unit_tests/services/test_conversation_variable_updater.py b/api/tests/unit_tests/services/test_conversation_variable_updater.py deleted file mode 100644 index 20f7caa78e..0000000000 --- a/api/tests/unit_tests/services/test_conversation_variable_updater.py +++ /dev/null @@ -1,75 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from dify_graph.variables import StringVariable -from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater - - -class TestConversationVariableUpdater: - def test_should_update_conversation_variable_data_and_commit(self): - """Test update persists serialized variable data when the row exists.""" - conversation_id = "conv-123" - variable = StringVariable( - id="var-123", - name="topic", - value="new value", - ) - expected_json = variable.model_dump_json() - - row = SimpleNamespace(data="old value") - session = MagicMock() - session.scalar.return_value = row - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - updater.update(conversation_id=conversation_id, variable=variable) - - session_maker.assert_called_once_with() - session.scalar.assert_called_once() - stmt = session.scalar.call_args.args[0] - compiled_params = stmt.compile().params - assert variable.id in compiled_params.values() - assert conversation_id in compiled_params.values() - assert row.data == expected_json - session.commit.assert_called_once() - - def test_should_raise_not_found_error_when_conversation_variable_missing(self): - """Test update raises ConversationVariableNotFoundError when no matching row exists.""" - conversation_id = "conv-404" - variable = StringVariable( - id="var-404", - name="topic", - value="value", - ) - - session = MagicMock() - session.scalar.return_value = None - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): - updater.update(conversation_id=conversation_id, variable=variable) - - session.commit.assert_not_called() - - def test_should_do_nothing_when_flush_is_called(self): - """Test flush currently behaves as a no-op and returns None.""" - session_maker = MagicMock() - updater = ConversationVariableUpdater(session_maker) - - result = updater.flush() - - assert result is None - session_maker.assert_not_called() diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py deleted file mode 100644 index 9ef314cb9e..0000000000 --- a/api/tests/unit_tests/services/test_credit_pool_service.py +++ /dev/null @@ -1,157 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -import services.credit_pool_service as credit_pool_service_module -from core.errors.error import QuotaExceededError -from models import TenantCreditPool -from services.credit_pool_service import CreditPoolService - - -@pytest.fixture -def mock_credit_deduction_setup(): - """Fixture providing common setup for credit deduction tests.""" - pool = SimpleNamespace(remaining_credits=50) - fake_engine = MagicMock() - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool) - mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine)) - mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context) - - return { - "pool": pool, - "fake_engine": fake_engine, - "session": session, - "session_context": session_context, - "patches": (mock_get_pool, mock_db, mock_session), - } - - -class TestCreditPoolService: - def test_should_create_default_pool_with_trial_type_and_configured_quota(self): - """Test create_default_pool persists a trial pool using configured hosted credits.""" - tenant_id = "tenant-123" - hosted_pool_credits = 5000 - - with ( - patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits), - patch.object(credit_pool_service_module, "db") as mock_db, - ): - pool = CreditPoolService.create_default_pool(tenant_id) - - assert isinstance(pool, TenantCreditPool) - assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" - assert pool.quota_limit == hosted_pool_credits - assert pool.quota_used == 0 - mock_db.session.add.assert_called_once_with(pool) - mock_db.session.commit.assert_called_once() - - def test_should_return_first_pool_from_query_when_get_pool_called(self): - """Test get_pool queries by tenant and pool_type and returns first result.""" - tenant_id = "tenant-123" - pool_type = "enterprise" - expected_pool = MagicMock(spec=TenantCreditPool) - - with patch.object(credit_pool_service_module, "db") as mock_db: - query = mock_db.session.query.return_value - filtered_query = query.filter_by.return_value - filtered_query.first.return_value = expected_pool - - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type) - - assert result == expected_pool - mock_db.session.query.assert_called_once_with(TenantCreditPool) - query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type) - filtered_query.first.assert_called_once() - - def test_should_return_false_when_pool_not_found_in_check_credits_available(self): - """Test check_credits_available returns False when tenant has no pool.""" - with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10) - - assert result is False - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_true_when_remaining_credits_cover_required_amount(self): - """Test check_credits_available returns True when remaining credits are sufficient.""" - pool = SimpleNamespace(remaining_credits=100) - - with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is True - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_false_when_remaining_credits_are_insufficient(self): - """Test check_credits_available returns False when required credits exceed remaining credits.""" - pool = SimpleNamespace(remaining_credits=30) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is False - - def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self): - """Test check_and_deduct_credits raises when tenant credit pool does not exist.""" - with patch.object(CreditPoolService, "get_pool", return_value=None): - with pytest.raises(QuotaExceededError, match="Credit pool not found"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self): - """Test check_and_deduct_credits raises when remaining credits are zero or negative.""" - pool = SimpleNamespace(remaining_credits=0) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - with pytest.raises(QuotaExceededError, match="No credits remaining"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits updates quota_used by the actual deducted amount.""" - tenant_id = "tenant-123" - pool_type = "trial" - credits_required = 200 - remaining_credits = 120 - expected_deducted_credits = 120 - - mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits - patches = mock_credit_deduction_setup["patches"] - session = mock_credit_deduction_setup["session"] - - with patches[0], patches[1], patches[2]: - result = CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=credits_required, - pool_type=pool_type, - ) - - assert result == expected_deducted_credits - session.execute.assert_called_once() - session.commit.assert_called_once() - - stmt = session.execute.call_args.args[0] - compiled_params = stmt.compile().params - assert tenant_id in compiled_params.values() - assert pool_type in compiled_params.values() - assert expected_deducted_credits in compiled_params.values() - - def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits translates DB update failures to QuotaExceededError.""" - mock_credit_deduction_setup["pool"].remaining_credits = 50 - mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure") - session = mock_credit_deduction_setup["session"] - - patches = mock_credit_deduction_setup["patches"] - mock_logger = patch.object(credit_pool_service_module, "logger") - - with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj: - with pytest.raises(QuotaExceededError, match="Failed to deduct credits"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - session.commit.assert_not_called() - mock_logger_obj.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py deleted file mode 100644 index 4974d6c1ef..0000000000 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ /dev/null @@ -1,305 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - - -class DatasetPermissionTestDataFactory: - """Factory class for creating test data and mock objects for dataset permission tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.permission = permission - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_dataset_permission_mock( - dataset_id: str = "dataset-123", - account_id: str = "user-789", - **kwargs, - ) -> Mock: - """Create a mock dataset permission record.""" - permission = Mock(spec=DatasetPermission) - permission.dataset_id = dataset_id - permission.account_id = account_id - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - -class TestDatasetPermissionService: - """ - Comprehensive unit tests for DatasetService.check_dataset_permission method. - - This test suite covers all permission scenarios including: - - Cross-tenant access restrictions - - Owner privilege checks - - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) - - Explicit permission checks for PARTIAL_TEAM - - Error conditions and logging - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with patch("services.dataset_service.db.session") as mock_session: - yield { - "db_session": mock_session, - } - - @pytest.fixture - def mock_logging_dependencies(self): - """Mock setup for logging tests.""" - with patch("services.dataset_service.logger") as mock_logging: - yield { - "logging": mock_logging, - } - - def _assert_permission_check_passes(self, dataset: Mock, user: Mock): - """Helper method to verify that permission check passes without raising exceptions.""" - # Should not raise any exception - DatasetService.check_dataset_permission(dataset, user) - - def _assert_permission_check_fails( - self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset." - ): - """Helper method to verify that permission check fails with expected error.""" - with pytest.raises(NoPermissionError, match=expected_message): - DatasetService.check_dataset_permission(dataset, user) - - def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str): - """Helper method to verify database query calls for permission checks.""" - mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id) - - def _assert_database_query_not_called(self, mock_session: Mock): - """Helper method to verify that database query was not called.""" - mock_session.query.assert_not_called() - - # ==================== Cross-Tenant Access Tests ==================== - - def test_permission_check_different_tenant_should_fail(self): - """Test that users from different tenants cannot access dataset regardless of other permissions.""" - # Create dataset and user from different tenants - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM - ) - user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR - ) - - # Should fail due to different tenant - self._assert_permission_check_fails(dataset, user) - - # ==================== Owner Privilege Tests ==================== - - def test_owner_can_access_any_dataset(self): - """Test that tenant owners can access any dataset regardless of permission level.""" - # Create dataset with restrictive permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - - # Create owner user - owner_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="owner-999", role=TenantAccountRole.OWNER - ) - - # Owner should have access regardless of dataset permission - self._assert_permission_check_passes(dataset, owner_user) - - # ==================== ONLY_ME Permission Tests ==================== - - def test_only_me_permission_creator_can_access(self): - """Test ONLY_ME permission allows only the dataset creator to access.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should be able to access - self._assert_permission_check_passes(dataset, creator_user) - - def test_only_me_permission_others_cannot_access(self): - """Test ONLY_ME permission denies access to non-creators.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Non-creator should be denied access - self._assert_permission_check_fails(dataset, normal_user) - - # ==================== ALL_TEAM Permission Tests ==================== - - def test_all_team_permission_allows_access(self): - """Test ALL_TEAM permission allows any team member to access the dataset.""" - # Create dataset with ALL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM) - - # Create different types of team members - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - editor_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="editor-456", role=TenantAccountRole.EDITOR - ) - - # All team members should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_permission_check_passes(dataset, editor_user) - - # ==================== PARTIAL_TEAM Permission Tests ==================== - - def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows creator to access without database query.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should have access without database query - self._assert_permission_check_passes(dataset, creator_user) - self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"]) - - def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows users with explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return a permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=normal_user.id - ) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission - - # User with explicit permission should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission denies users without explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # User without explicit permission should be denied access - self._assert_permission_check_fails(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies): - """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create a different user (not the creator) - other_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="other-user-123", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Non-creator without explicit permission should be denied access - self._assert_permission_check_fails(dataset, other_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id) - - # ==================== Enum Usage Tests ==================== - - def test_partial_team_permission_uses_correct_enum(self): - """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals.""" - # Create dataset with PARTIAL_TEAM permission using enum - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should always have access regardless of permission level - self._assert_permission_check_passes(dataset, creator_user) - - # ==================== Logging Tests ==================== - - def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies): - """Test that permission denied events are properly logged for debugging purposes.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Attempt permission check (should fail) - with pytest.raises(NoPermissionError): - DatasetService.check_dataset_permission(dataset, normal_user) - - # Verify debug message was logged with correct user and dataset information - mock_logging_dependencies["logging"].debug.assert_called_with( - "User %s does not have permission to access dataset %s", normal_user.id, dataset.id - ) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py deleted file mode 100644 index a1d2f6410c..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Unit tests for non-SQL DocumentService orchestration behaviors. - -This file intentionally keeps only collaborator-oriented document indexing -orchestration tests. SQL-backed dataset lifecycle cases are covered by -integration tests under testcontainers. -""" - -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Document -from services.errors.document import DocumentIndexingError - - -class DatasetServiceUnitDataFactory: - """Factory for creating lightweight document doubles used in unit tests.""" - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - indexing_status: str = "completed", - is_paused: bool = False, - ) -> Mock: - """Create a document-shaped mock for DocumentService orchestration tests.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.indexing_status = indexing_status - document.is_paused = is_paused - document.paused_by = None - document.paused_at = None - return document - - -class TestDatasetServiceDocumentIndexing: - """Unit tests for pause/recover/retry orchestration without SQL assertions.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Patch non-SQL collaborators used by DocumentService methods.""" - with ( - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - mock_current_user.id = "user-123" - yield { - "redis_client": mock_redis, - "db_session": mock_db, - "current_user": mock_current_user, - } - - def test_pause_document_success(self, mock_document_service_dependencies): - """Pause a document that is currently in an indexable status.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") - - # Act - from services.dataset_service import DocumentService - - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( - f"document_{document.id}_is_paused", - "True", - ) - - def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - """Raise DocumentIndexingError when pausing a completed document.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - - # Act / Assert - from services.dataset_service import DocumentService - - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - def test_recover_document_success(self, mock_document_service_dependencies): - """Recover a paused document and dispatch the recover indexing task.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - - # Act - with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - from services.dataset_service import DocumentService - - DocumentService.recover_document(document) - - # Assert - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( - f"document_{document.id}_is_paused" - ) - recover_task.delay.assert_called_once_with(document.dataset_id, document.id) - - def test_retry_document_indexing_success(self, mock_document_service_dependencies): - """Reset documents to waiting state and dispatch retry indexing task.""" - # Arrange - dataset_id = "dataset-123" - documents = [ - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), - ] - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - with patch("services.dataset_service.retry_document_indexing_task") as retry_task: - from services.dataset_service import DocumentService - - DocumentService.retry_document(dataset_id, documents) - - # Assert - assert all(document.indexing_status == "waiting" for document in documents) - assert mock_document_service_dependencies["db_session"].add.call_count == 2 - assert mock_document_service_dependencies["db_session"].commit.call_count == 2 - assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 - retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py deleted file mode 100644 index abff48347e..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ /dev/null @@ -1,100 +0,0 @@ -import datetime -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.dataset_service import DocumentService -from tests.unit_tests.conftest import redis_mock - - -class DocumentBatchUpdateTestDataFactory: - """Factory class for creating test data and mock objects for document batch update tests.""" - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - return dataset - - @staticmethod - def create_user_mock(user_id: str = "user-789") -> Mock: - """Create a mock user.""" - user = Mock() - user.id = user_id - return user - - @staticmethod - def create_document_mock( - document_id: str = "doc-1", - name: str = "test_document.pdf", - enabled: bool = True, - archived: bool = False, - indexing_status: str = "completed", - completed_at: datetime.datetime | None = None, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.name = name - document.enabled = enabled - document.archived = archived - document.indexing_status = indexing_status - document.completed_at = completed_at or datetime.datetime.now() - - document.disabled_at = None - document.disabled_by = None - document.archived_at = None - document.archived_by = None - document.updated_at = None - - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - -class TestDatasetServiceBatchUpdateDocumentStatus: - """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Common mock setup for document service dependencies.""" - with ( - patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_document": mock_get_doc, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): - """Test that ValueError is raised when an invalid action is provided.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = doc - - redis_mock.reset_mock() - redis_mock.get.return_value = None - - invalid_action = "invalid_action" - with pytest.raises(ValueError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user - ) - - assert invalid_action in str(exc_info.value) - assert "Invalid action" in str(exc_info.value) - - redis_mock.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py deleted file mode 100644 index f8c5270656..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" - -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest - -from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity - - -class TestDatasetServiceCreateRagPipelineDatasetNonSQL: - """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" - - @pytest.fixture - def mock_rag_pipeline_dependencies(self): - """Patch database session and current_user for validation-only unit coverage.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - yield { - "db_session": mock_db, - "current_user_mock": mock_current_user, - } - - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Raise ValueError when current_user.id is unavailable before SQL persistence.""" - # Arrange - tenant_id = str(uuid4()) - mock_rag_pipeline_dependencies["current_user_mock"].id = None - - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="Test Dataset", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act / Assert - with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, - rag_pipeline_dataset_create_entity=entity, - ) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index bd226f7536..9a513c3fe6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -70,16 +71,16 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch # Minimal knowledge_config stub that satisfies pre-lock code info_list = types.SimpleNamespace(data_source_type="upload_file") data_source = types.SimpleNamespace(info_list=info_list) knowledge_config = types.SimpleNamespace( - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model=None, embedding_model_provider=None, retrieval_model=None, @@ -125,13 +126,13 @@ def test_add_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # skip embedding/token calculation branch + dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX # Minimal args required by add_segment args = { @@ -168,10 +169,10 @@ def test_multi_create_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # again, skip high_quality path + dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py deleted file mode 100644 index a7e1a011f6..0000000000 --- a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Unit tests for archived workflow run deletion service. -""" - -from unittest.mock import MagicMock, patch - - -class TestArchivedWorkflowRunDeletion: - def test_delete_by_run_id_calls_delete_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion() - repo = MagicMock() - repo.get_archived_run_ids.return_value = {"run-1"} - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - session = MagicMock() - session.get.return_value = run - - session_maker = MagicMock() - session_maker.return_value.__enter__.return_value = session - session_maker.return_value.__exit__.return_value = None - mock_db = MagicMock() - mock_db.engine = MagicMock() - - with ( - patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), - patch( - "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", - return_value=session_maker, - autospec=True, - ), - patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True), - patch.object( - deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True - ) as mock_delete_run, - ): - result = deleter.delete_by_run_id("run-1") - - assert result.success is True - mock_delete_run.assert_called_once_with(run) - - def test_delete_run_dry_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion(dry_run=True) - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo: - result = deleter._delete_run(run) - - assert result.success is True - mock_get_repo.assert_not_called() diff --git a/api/tests/unit_tests/services/test_document_service_display_status.py b/api/tests/unit_tests/services/test_document_service_display_status.py deleted file mode 100644 index cb2e2940c8..0000000000 --- a/api/tests/unit_tests/services/test_document_service_display_status.py +++ /dev/null @@ -1,8 +0,0 @@ -from services.dataset_service import DocumentService - - -def test_normalize_display_status_alias_mapping(): - assert DocumentService.normalize_display_status("ACTIVE") == "available" - assert DocumentService.normalize_display_status("enabled") == "available" - assert DocumentService.normalize_display_status("archived") == "archived" - assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py deleted file mode 100644 index a3b1f46436..0000000000 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ /dev/null @@ -1,841 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, DefaultEndUserSessionID, EndUser -from services.end_user_service import EndUserService - - -class TestEndUserServiceFactory: - """Factory class for creating test data and mock objects for end user service tests.""" - - @staticmethod - def create_app_mock( - app_id: str = "app-123", - tenant_id: str = "tenant-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock(spec=App) - app.id = app_id - app.tenant_id = tenant_id - app.name = name - return app - - @staticmethod - def create_end_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-456", - app_id: str = "app-123", - session_id: str = "session-001", - type: InvokeFrom = InvokeFrom.SERVICE_API, - is_anonymous: bool = False, - ) -> MagicMock: - """Create a mock EndUser object.""" - end_user = MagicMock(spec=EndUser) - end_user.id = user_id - end_user.tenant_id = tenant_id - end_user.app_id = app_id - end_user.session_id = session_id - end_user.type = type - end_user.is_anonymous = is_anonymous - end_user.external_user_id = session_id - return end_user - - -class TestEndUserServiceGetEndUserById: - """Unit tests for EndUserService.get_end_user_by_id method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): - """Test successful retrieval of end user by ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = mock_end_user - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result == mock_end_user - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.first.assert_called_once() - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): - """Test retrieval of non-existent end user returns None.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result is None - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): - """Test that query parameters are correctly applied.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - # Verify the where clause was called with the correct conditions - call_args = mock_query.where.call_args[0] - assert len(call_args) == 3 - # Check that the conditions match the expected filters - # (We can't easily test the exact conditions without importing SQLAlchemy) - - -class TestEndUserServiceGetOrCreateEndUser: - """Unit tests for EndUserService.get_or_create_end_user method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user with specific user_id.""" - # Arrange - app_mock = factory.create_app_mock() - user_id = "user-123" - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, user_id) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id - ) - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user without user_id (None).""" - # Arrange - app_mock = factory.create_app_mock() - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, None) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None - ) - - -class TestEndUserServiceGetOrCreateEndUserByType: - """ - Unit tests for EndUserService.get_or_create_end_user_by_type method. - - This test suite covers: - - Creating end users with different InvokeFrom types - - Type migration for legacy users - - Query ordering and prioritization - - Session management - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): - """Test creating a new end user with specific user_id.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - # Verify new EndUser was created with correct parameters - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - added_user = mock_session.add.call_args[0][0] - assert added_user.tenant_id == tenant_id - assert added_user.app_id == app_id - assert added_user.type == type_enum - assert added_user.session_id == user_id - assert added_user.external_user_id == user_id - assert added_user._is_anonymous is False - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): - """Test creating a new end user with default session ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = None - type_enum = InvokeFrom.WEB_APP - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): - """Test retrieving existing user with same type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - mock_session.add.assert_not_called() - mock_session.commit.assert_not_called() - mock_logger.info.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): - """Test upgrading existing user with different type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - old_type = InvokeFrom.WEB_APP - new_type = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - assert existing_user.type == new_type - mock_session.commit.assert_called_once() - mock_logger.info.assert_called_once() - logger_call_args = mock_logger.info.call_args[0] - assert "Upgrading legacy EndUser" in logger_call_args[0] - # The old and new types are passed as separate arguments - assert mock_logger.info.call_args[0][1] == existing_user.id - assert mock_logger.info.call_args[0][2] == old_type - assert mock_logger.info.call_args[0][3] == new_type - assert mock_logger.info.call_args[0][4] == user_id - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): - """Test that query ordering prioritizes exact type matches.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - target_type = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - mock_query.order_by.assert_called_once() - # Verify that case statement is used for ordering - order_by_call = mock_query.order_by.call_args[0][0] - # The exact structure depends on SQLAlchemy's case implementation - # but we can verify it was called - - # Test 10: Session context manager properly closes - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_session_context_manager_closes(self, mock_db, mock_session_class, factory): - """Test that Session context manager is properly used.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_id=app_id, - user_id=user_id, - ) - - # Assert - # Verify context manager was entered and exited - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): - """Test that all InvokeFrom enum values are supported.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.type == invoke_type - - -class TestEndUserServiceCreateEndUserBatch: - """Unit tests for EndUserService.create_end_user_batch method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): - """Test batch creation with empty app_ids list.""" - # Arrange - tenant_id = "tenant-123" - app_ids: list[str] = [] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert result == {} - mock_session_class.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_default_session_id(self, mock_db, mock_session_class): - """Test batch creation with empty user_id (uses default session).""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - for app_id, end_user in result.items(): - assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): - """Test that duplicate app_ids are deduplicated while preserving order.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - # Should have 3 unique app_ids in original order - assert len(result) == 3 - assert "app-456" in result - assert "app-789" in result - assert "app-123" in result - - # Verify the order is preserved - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 3 - assert added_users[0].app_id == "app-456" - assert added_users[1].app_id == "app-789" - assert added_users[2].app_id == "app-123" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation when all users already exist.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - assert result["app-456"] == existing_user1 - assert result["app-789"] == existing_user2 - mock_session.add_all.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation with some existing and some new users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - # app-789 and app-123 don't exist - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 3 - assert result["app-456"] == existing_user1 - assert "app-789" in result - assert "app-123" in result - - # Should create 2 new users - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 2 - - mock_session.commit.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): - """Test batch creation handles duplicates in existing users gracefully.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Simulate duplicate records in database - existing_user1 = factory.create_end_user_mock( - user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - # Should prefer the first one found - assert result["app-456"] == existing_user1 - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): - """Test batch creation with all InvokeFrom types.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - added_user = mock_session.add_all.call_args[0][0][0] - assert added_user.type == invoke_type - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): - """Test batch creation with single app_id.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - assert "app-456" in result - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 1 - assert added_users[0].app_id == "app-456" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): - """Test batch creation correctly sets anonymous flag.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - - # Test with regular user ID - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - authenticated user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is False - - # Test with default session ID - mock_session.reset_mock() - mock_query.reset_mock() - mock_query.all.return_value = [] - - # Act - anonymous user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_ids=app_ids, - user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): - """Test that batch creation uses efficient single query for existing users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - # Should make exactly one query to check for existing users - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.all.assert_called_once() - - # Verify the where clause uses .in_() for app_ids - where_call = mock_query.where.call_args[0] - # The exact structure depends on SQLAlchemy implementation - # but we can verify it was called with the right parameters - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_session_context_manager(self, mock_db, mock_session_class): - """Test that batch creation properly uses session context manager.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_feedback_service.py b/api/tests/unit_tests/services/test_feedback_service.py deleted file mode 100644 index 1f70839ee2..0000000000 --- a/api/tests/unit_tests/services/test_feedback_service.py +++ /dev/null @@ -1,626 +0,0 @@ -import csv -import io -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.feedback_service import FeedbackService - - -class TestFeedbackServiceFactory: - """Factory class for creating test data and mock objects for feedback service tests.""" - - @staticmethod - def create_feedback_mock( - feedback_id: str = "feedback-123", - app_id: str = "app-456", - conversation_id: str = "conv-789", - message_id: str = "msg-001", - rating: str = "like", - content: str | None = "Great response!", - from_source: str = "user", - from_account_id: str | None = None, - from_end_user_id: str | None = "end-user-001", - created_at: datetime | None = None, - ) -> MagicMock: - """Create a mock MessageFeedback object.""" - feedback = MagicMock() - feedback.id = feedback_id - feedback.app_id = app_id - feedback.conversation_id = conversation_id - feedback.message_id = message_id - feedback.rating = rating - feedback.content = content - feedback.from_source = from_source - feedback.from_account_id = from_account_id - feedback.from_end_user_id = from_end_user_id - feedback.created_at = created_at or datetime.now() - return feedback - - @staticmethod - def create_message_mock( - message_id: str = "msg-001", - query: str = "What is AI?", - answer: str = "AI stands for Artificial Intelligence.", - inputs: dict | None = None, - created_at: datetime | None = None, - ): - """Create a mock Message object.""" - - # Create a simple object with instance attributes - # Using a class with __init__ ensures attributes are instance attributes - class Message: - def __init__(self): - self.id = message_id - self.query = query - self.answer = answer - self.inputs = inputs - self.created_at = created_at or datetime.now() - - return Message() - - @staticmethod - def create_conversation_mock( - conversation_id: str = "conv-789", - name: str | None = "Test Conversation", - ) -> MagicMock: - """Create a mock Conversation object.""" - conversation = MagicMock() - conversation.id = conversation_id - conversation.name = name - return conversation - - @staticmethod - def create_app_mock( - app_id: str = "app-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock() - app.id = app_id - app.name = name - return app - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - name: str = "Test Admin", - ) -> MagicMock: - """Create a mock Account object.""" - account = MagicMock() - account.id = account_id - account.name = name - return account - - -class TestFeedbackService: - """ - Comprehensive unit tests for FeedbackService. - - This test suite covers: - - CSV and JSON export formats - - All filter combinations - - Edge cases and error handling - - Response validation - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestFeedbackServiceFactory() - - @pytest.fixture - def sample_feedback_data(self, factory): - """Create sample feedback data for testing.""" - feedback = factory.create_feedback_mock( - rating="like", - content="Excellent answer!", - from_source="user", - ) - message = factory.create_message_mock( - query="What is Python?", - answer="Python is a programming language.", - ) - conversation = factory.create_conversation_mock(name="Python Discussion") - app = factory.create_app_mock(name="AI Assistant") - account = factory.create_account_mock(name="Admin User") - - return [(feedback, message, conversation, app, account)] - - # Test 01: CSV Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data): - """Test basic CSV export with single feedback record.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - assert response.mimetype == "text/csv" - assert "charset=utf-8-sig" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"] - - # Verify CSV content - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - - assert len(rows) == 1 - assert rows[0]["feedback_rating"] == "👍" - assert rows[0]["feedback_rating_raw"] == "like" - assert rows[0]["feedback_comment"] == "Excellent answer!" - assert rows[0]["user_query"] == "What is Python?" - assert rows[0]["ai_response"] == "Python is a programming language." - - # Test 02: JSON Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data): - """Test basic JSON export with metadata structure.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - assert response.mimetype == "application/json" - assert "charset=utf-8" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - - # Verify JSON structure - json_content = json.loads(response.get_data(as_text=True)) - assert "export_info" in json_content - assert "feedback_data" in json_content - assert json_content["export_info"]["app_id"] == "app-456" - assert json_content["export_info"]["total_records"] == 1 - assert len(json_content["feedback_data"]) == 1 - - # Test 03: Filter by from_source - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_from_source(self, mock_db, factory): - """Test filtering by feedback source (user/admin).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", from_source="admin") - - # Assert - mock_query.filter.assert_called() - - # Test 04: Filter by rating - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_rating(self, mock_db, factory): - """Test filtering by rating (like/dislike).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", rating="dislike") - - # Assert - mock_query.filter.assert_called() - - # Test 05: Filter by has_comment (True) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory): - """Test filtering for feedback with comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=True) - - # Assert - mock_query.filter.assert_called() - - # Test 06: Filter by has_comment (False) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory): - """Test filtering for feedback without comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=False) - - # Assert - mock_query.filter.assert_called() - - # Test 07: Filter by date range - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_date_range(self, mock_db, factory): - """Test filtering by start and end dates.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - assert mock_query.filter.call_count >= 2 # Called for both start and end dates - - # Test 08: Invalid date format - start_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_start_date(self, mock_db): - """Test error handling for invalid start_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid start_date format"): - FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date") - - # Test 09: Invalid date format - end_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_end_date(self, mock_db): - """Test error handling for invalid end_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid end_date format"): - FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45") - - # Test 10: Unsupported format - def test_export_feedbacks_unsupported_format(self): - """Test error handling for unsupported export format.""" - # Act & Assert - with pytest.raises(ValueError, match="Unsupported format"): - FeedbackService.export_feedbacks(app_id="app-456", format_type="xml") - - # Test 11: Empty result set - CSV - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_csv(self, mock_db): - """Test CSV export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - assert len(rows) == 0 - # But headers should still be present - assert reader.fieldnames is not None - - # Test 12: Empty result set - JSON - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_json(self, mock_db): - """Test JSON export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["export_info"]["total_records"] == 0 - assert len(json_content["feedback_data"]) == 0 - - # Test 13: Long response truncation - @patch("services.feedback_service.db") - def test_export_feedbacks_long_response_truncation(self, mock_db, factory): - """Test that long AI responses are truncated to 500 characters.""" - # Arrange - long_answer = "A" * 600 # 600 characters - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(answer=long_answer) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - ai_response = json_content["feedback_data"][0]["ai_response"] - assert len(ai_response) == 503 # 500 + "..." - assert ai_response.endswith("...") - - # Test 14: Null account (end user feedback) - @patch("services.feedback_service.db") - def test_export_feedbacks_null_account(self, mock_db, factory): - """Test handling of feedback from end users (no account).""" - # Arrange - feedback = factory.create_feedback_mock(from_account_id=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = None # No account for end user - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["from_account_name"] == "" - - # Test 15: Null conversation name - @patch("services.feedback_service.db") - def test_export_feedbacks_null_conversation_name(self, mock_db, factory): - """Test handling of conversations without names.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock() - conversation = factory.create_conversation_mock(name=None) - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["conversation_name"] == "" - - # Test 16: Dislike rating emoji - @patch("services.feedback_service.db") - def test_export_feedbacks_dislike_rating(self, mock_db, factory): - """Test that dislike rating shows thumbs down emoji.""" - # Arrange - feedback = factory.create_feedback_mock(rating="dislike") - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_rating"] == "👎" - assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike" - - # Test 17: Combined filters - @patch("services.feedback_service.db") - def test_export_feedbacks_combined_filters(self, mock_db, factory): - """Test applying multiple filters simultaneously.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - from_source="admin", - rating="like", - has_comment=True, - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - # Should have called filter multiple times for each condition - assert mock_query.filter.call_count >= 4 - - # Test 18: Message query fallback to inputs - @patch("services.feedback_service.db") - def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory): - """Test fallback to inputs.query when message.query is None.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"}) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["user_query"] == "Query from inputs" - - # Test 19: Empty feedback content - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_feedback_content(self, mock_db, factory): - """Test handling of feedback with empty/null content.""" - # Arrange - feedback = factory.create_feedback_mock(content=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_comment"] == "" - assert json_content["feedback_data"][0]["has_comment"] == "No" - - # Test 20: CSV headers validation - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data): - """Test that CSV contains all expected headers.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - expected_headers = [ - "feedback_id", - "app_name", - "app_id", - "conversation_id", - "conversation_name", - "message_id", - "user_query", - "ai_response", - "feedback_rating", - "feedback_rating_raw", - "feedback_comment", - "feedback_source", - "feedback_date", - "message_date", - "from_account_name", - "from_end_user_id", - "has_comment", - ] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - assert list(reader.fieldnames) == expected_headers diff --git a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py deleted file mode 100644 index 7b4d349e33..0000000000 --- a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Unit tests for `services.file_service.FileService` helpers. - -We keep these tests focused on: -- ZIP tempfile building (sanitization + deduplication + content writes) -- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`) -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from zipfile import ZipFile - -import pytest - -import services.file_service as file_service_module -from services.file_service import FileService - - -def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure ZIP entry names are safe and unique while preserving extensions.""" - - # Arrange: three upload files that all sanitize down to the same basename ("b.txt"). - upload_files: list[Any] = [ - SimpleNamespace(name="a/b.txt", key="k1"), - SimpleNamespace(name="c/b.txt", key="k2"), - SimpleNamespace(name="../b.txt", key="k3"), - ] - - # Stream distinct bytes per key so we can verify content is written to the right entry. - data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} - - def _load(key: str, stream: bool = True) -> list[bytes]: - # Return the corresponding chunks for this key (the production code iterates chunks). - assert stream is True - return data_by_key[key] - - monkeypatch.setattr(file_service_module.storage, "load", _load) - - # Act: build zip in a tempfile. - with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: - with ZipFile(tmp, mode="r") as zf: - # Assert: names are sanitized (no directory components) and deduped with suffixes. - assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] - - # Assert: each entry contains the correct bytes from storage. - assert zf.read("b.txt") == b"one" - assert zf.read("b (1).txt") == b"two" - assert zf.read("b (2).txt") == b"three" - - -def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure empty input returns an empty mapping without hitting the database.""" - - class _Session: - def scalars(self, _stmt): # type: ignore[no-untyped-def] - raise AssertionError("db.session.scalars should not be called for empty id lists") - - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session())) - - assert FileService.get_upload_files_by_ids("tenant-1", []) == {} - - -def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" - - upload_files: list[Any] = [ - SimpleNamespace(id="file-1", tenant_id="tenant-1"), - SimpleNamespace(id="file-2", tenant_id="tenant-1"), - ] - - class _ScalarResult: - def __init__(self, items: list[Any]) -> None: - self._items = items - - def all(self) -> list[Any]: - return self._items - - class _Session: - def __init__(self, items: list[Any]) -> None: - self._items = items - self.calls: list[object] = [] - - def scalars(self, stmt): # type: ignore[no-untyped-def] - # Capture the statement so we can at least assert the query path is taken. - self.calls.append(stmt) - return _ScalarResult(self._items) - - session = _Session(upload_files) - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session)) - - # Provide duplicates to ensure callers can safely pass repeated ids. - result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"]) - - assert set(result.keys()) == {"file-1", "file-2"} - assert result["file-1"].id == "file-1" - assert result["file-2"].id == "file-2" - assert len(session.calls) == 1 diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index a23c44b26e..3b1c1fcf17 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -313,7 +313,8 @@ class TestEmailDeliveryTestHandler: recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], ) - subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") assert subs["node_title"] == "title" assert subs["form_content"] == "content" diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 4b8bdde46b..e7740ef93a 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, EndUser, Message from services.errors.message import ( FirstMessageNotExistsError, @@ -820,14 +821,14 @@ class TestMessageServiceFeedback: app_model=app, message_id="msg-123", user=user, - rating="like", + rating=FeedbackRating.LIKE, content="Good answer", ) # Assert - assert result.rating == "like" + assert result.rating == FeedbackRating.LIKE assert result.content == "Good answer" - assert result.from_source == "user" + assert result.from_source == FeedbackFromSource.USER mock_db.session.add.assert_called_once() mock_db.session.commit.assert_called_once() @@ -852,13 +853,13 @@ class TestMessageServiceFeedback: app_model=app, message_id="msg-123", user=user, - rating="dislike", + rating=FeedbackRating.DISLIKE, content="Bad answer", ) # Assert assert result == feedback - assert feedback.rating == "dislike" + assert feedback.rating == FeedbackRating.DISLIKE assert feedback.content == "Bad answer" mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..bbdc16d4f8 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_service.py @@ -0,0 +1,558 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataArgs, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +@dataclass +class _DocumentStub: + id: str + name: str + uploader: str + upload_date: datetime + last_update_date: datetime + data_source_type: str + doc_metadata: dict[str, object] | None + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + mocked_db = mocker.patch("services.metadata_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.fixture +def mock_redis_client(mocker: MockerFixture) -> MagicMock: + return mocker.patch("services.metadata_service.redis_client") + + +@pytest.fixture +def mock_current_account(mocker: MockerFixture) -> MagicMock: + mock_user = SimpleNamespace(id="user-1") + return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1")) + + +def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub: + now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC) + return _DocumentStub( + id=document_id, + name=f"doc-{document_id}", + uploader="qa@example.com", + upload_date=now, + last_update_date=now, + data_source_type="upload_file", + doc_metadata=doc_metadata, + ) + + +def _dataset(**kwargs: Any) -> Dataset: + return cast(Dataset, SimpleNamespace(**kwargs)) + + +def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="x" * 256) + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="priority") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + mock_current_account.assert_called_once() + + +def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_persist_metadata_when_input_is_valid( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="number", name="score") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + assert result.tenant_id == "tenant-1" + assert result.dataset_id == "dataset-1" + assert result.type == "number" + assert result.name == "score" + assert result.created_by == "user-1" + mock_db.session.add.assert_called_once_with(result) + mock_db.session.commit.assert_called_once() + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + too_long_name = "x" * 256 + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) + + +def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_update_bound_documents_and_return_metadata( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) + mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) + + metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) + bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] + + doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) + doc_2 = _build_document("2", None) + mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") + mock_get_documents.return_value = [doc_1, doc_2] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") + + # Assert + assert result is metadata + assert metadata.name == "new_name" + assert metadata.updated_by == "user-1" + assert metadata.updated_at == fixed_now + assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} + assert doc_2.doc_metadata == {"new_name": None} + mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = None + mock_db.session.query.side_effect = [query_duplicate, query_metadata] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_delete_metadata_should_remove_metadata_and_related_document_fields( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + metadata = SimpleNamespace(id="metadata-1", name="obsolete") + bindings = [SimpleNamespace(document_id="doc-1")] + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_metadata, query_bindings] + + document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) + + # Act + result = MetadataService.delete_metadata("dataset-1", "metadata-1") + + # Assert + assert result is metadata + assert document.doc_metadata == {"remaining": "value"} + mock_db.session.delete.assert_called_once_with(metadata) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_delete_metadata_should_return_none_when_metadata_is_missing( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + # Act + result = MetadataService.delete_metadata("dataset-1", "missing-id") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_get_built_in_fields_should_return_all_expected_fields() -> None: + # Arrange + expected_names = { + BuiltInField.document_name, + BuiltInField.uploader, + BuiltInField.upload_date, + BuiltInField.last_update_date, + BuiltInField.source, + } + + # Act + result = MetadataService.get_built_in_fields() + + # Assert + assert {item["name"] for item in result} == expected_names + assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] + + +def test_enable_built_in_field_should_return_immediately_when_already_enabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_enable_built_in_field_should_populate_documents_and_enable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + doc_1 = _build_document("1", {"custom": "value"}) + doc_2 = _build_document("2", None) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[doc_1, doc_2], + ) + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is True + assert doc_1.doc_metadata is not None + assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" + assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert doc_2.doc_metadata is not None + assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_disable_built_in_field_should_return_immediately_when_already_disabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document( + "1", + { + BuiltInField.document_name: "doc", + BuiltInField.uploader: "user", + BuiltInField.upload_date: 1.0, + BuiltInField.last_update_date: 2.0, + BuiltInField.source: MetadataDataSource.upload_file, + "custom": "keep", + }, + ) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[document], + ) + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is False + assert document.doc_metadata == {"custom": "keep"} + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + document = _build_document("1", {"legacy": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + delete_chain = mock_db.session.query.return_value.filter_by.return_value + delete_chain.delete.return_value = 1 + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata == {"priority": "high"} + delete_chain.delete.assert_called_once() + assert mock_db.session.commit.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document("1", {"existing": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata is not None + assert document.doc_metadata["existing"] == "value" + assert document.doc_metadata["new_key"] == "new_value" + assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert mock_db.session.commit.call_count == 1 + assert mock_db.session.add.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) + operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + Assert + with pytest.raises(ValueError, match="Document not found"): + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + mock_db.session.rollback.assert_called_once() + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") + + +@pytest.mark.parametrize( + ("dataset_id", "document_id", "expected_key"), + [ + ("dataset-1", None, "dataset_metadata_lock_dataset-1"), + (None, "doc-1", "document_metadata_lock_doc-1"), + ], +) +def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( + dataset_id: str | None, + document_id: str | None, + expected_key: str, + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id) + + # Assert + mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="knowledge base metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check("dataset-1", None) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="document metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check(None, "doc-1") + + +def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset( + id="dataset-1", + built_in_field_enabled=True, + doc_metadata=[ + {"id": "meta-1", "name": "priority", "type": "string"}, + {"id": "built-in", "name": "ignored", "type": "string"}, + {"id": "meta-2", "name": "score", "type": "number"}, + ], + ) + count_chain = mock_db.session.query.return_value.filter_by.return_value + count_chain.count.side_effect = [3, 1] + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result["built_in_field_enabled"] is True + assert result["doc_metadata"] == [ + {"id": "meta-1", "name": "priority", "type": "string", "count": 3}, + {"id": "meta-2", "name": "score", "type": "number", "count": 1}, + ] + + +def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None) + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result == {"doc_metadata": [], "built_in_field_enabled": False} + mock_db.session.query.assert_not_called() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..49e572584b --- /dev/null +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,808 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, +) +from models.provider import LoadBalancingModelConfig +from services.model_load_balancing_service import ModelLoadBalancingService + + +def _build_provider_credential_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ] + ) + + +def _build_model_credential_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ], + ) + + +def _build_provider_configuration( + *, + custom_provider: bool = False, + load_balancing_enabled: bool | None = None, + model_schema: ModelCredentialSchema | None = None, + provider_schema: ProviderCredentialSchema | None = None, +) -> MagicMock: + provider_configuration = MagicMock() + provider_configuration.provider = SimpleNamespace( + provider="openai", + model_credential_schema=model_schema, + provider_credential_schema=provider_schema, + ) + provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider) + provider_configuration.extract_secret_variables.return_value = ["api_key"] + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials + provider_configuration.get_provider_model_setting.return_value = ( + None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled) + ) + return provider_configuration + + +def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: + return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def service(mocker: MockerFixture) -> ModelLoadBalancingService: + # Arrange + provider_manager = MagicMock() + mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + svc = ModelLoadBalancingService() + svc.provider_manager = provider_manager + return svc + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.model_load_balancing_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.mark.parametrize( + ("method_name", "expected_provider_method"), + [ + ("enable_model_load_balancing", "enable_model_load_balancing"), + ("disable_model_load_balancing", "disable_model_load_balancing"), + ], +) +def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists( + method_name: str, + expected_provider_method: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + # Assert + getattr(provider_configuration, expected_provider_method).assert_called_once_with( + model="gpt-4o-mini", model_type=ModelType.LLM + ) + + +@pytest.mark.parametrize( + "method_name", + ["enable_model_load_balancing", "disable_model_load_balancing"], +) +def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing( + method_name: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=True, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace( + id="cfg-1", + name="primary", + encrypted_config=json.dumps({"api_key": "encrypted-key"}), + credential_id="cred-1", + enabled=True, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + return_value="plain-key", + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(False, 0), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + ) + + # Assert + assert is_enabled is True + assert len(configs) == 2 + assert configs[0]["name"] == "__inherit__" + assert configs[1]["name"] == "primary" + assert configs[1]["credentials"] == {"api_key": "plain-key"} + assert mock_db.session.add.call_count == 1 + assert mock_db.session.commit.call_count == 1 + + +def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=None, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + normal_config = SimpleNamespace( + id="cfg-1", + name="normal", + encrypted_config=json.dumps({"api_key": "bad-encrypted"}), + credential_id="cred-1", + enabled=True, + ) + inherit_config = SimpleNamespace( + id="cfg-2", + name="__inherit__", + encrypted_config="not-json", + credential_id=None, + enabled=False, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + normal_config, + inherit_config, + ] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + side_effect=ValueError("cannot decrypt"), + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(True, 15), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + config_from="predefined-model", + ) + + # Assert + assert is_enabled is False + assert configs[0]["name"] == "__inherit__" + assert configs[0]["credentials"] == {} + assert configs[1]["credentials"] == {"api_key": "bad-encrypted"} + assert configs[1]["in_cooldown"] is True + assert configs[1]["ttl"] == 15 + + +def test_get_load_balancing_config_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + +def test_get_load_balancing_config_should_return_none_when_config_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result is None + + +def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: { + "masked": credentials.get("api_key", "") + } + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) + mock_db.session.query.return_value.where.return_value.first.return_value = config + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result == { + "id": "cfg-1", + "name": "primary", + "credentials": {"masked": ""}, + "enabled": True, + } + + +def test_init_inherit_config_should_create_and_persist_inherit_configuration( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + model_type = ModelType.LLM + + # Act + inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type) + + # Assert + assert inherit_config.tenant_id == "tenant-1" + assert inherit_config.provider_name == "openai" + assert inherit_config.model_name == "gpt-4o-mini" + assert inherit_config.model_type == "text-generation" + assert inherit_config.name == "__inherit__" + mock_db.session.add.assert_called_once_with(inherit_config) + mock_db.session.commit.assert_called_once() + + +def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list( + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing configs"): + service.update_load_balancing_configs( # type: ignore[arg-type] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], "invalid-configs"), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config"): + service.update_load_balancing_configs( # type: ignore[list-item] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], ["bad-item"]), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"enabled": True}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config enabled"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "cfg-without-enabled"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + current_config = SimpleNamespace(id="cfg-1") + mock_db.session.scalars.return_value.all.return_value = [current_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-2", "name": "invalid", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None) + mock_db.session.scalars.return_value.all.return_value = [existing_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new-config", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config_1 = SimpleNamespace( + id="cfg-1", + name="existing-one", + enabled=True, + encrypted_config=json.dumps({"api_key": "old"}), + updated_at=None, + ) + existing_config_2 = SimpleNamespace( + id="cfg-2", + name="existing-two", + enabled=True, + encrypted_config=None, + updated_at=None, + ) + mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2] + mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"}) + mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache") + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [ + {"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}}, + {"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}}, + ], + "custom-model", + ) + + # Assert + assert existing_config_1.name == "updated-name" + assert existing_config_1.enabled is False + assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"} + assert mock_db.session.add.call_count == 1 + mock_db.session.delete.assert_called_once_with(existing_config_2) + assert mock_db.session.commit.call_count >= 3 + mock_clear_cache.assert_any_call("tenant-1", "cfg-1") + mock_clear_cache.assert_any_call("tenant-1", "cfg-2") + + +def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') + mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + # Assert + created_config = mock_db.session.add.call_args.args[0] + assert created_config.name == "Main Credential" + assert created_config.credential_id == "cred-1" + assert created_config.credential_source_type == "provider" + assert created_config.encrypted_config == '{"api_key":"enc"}' + mock_db.session.commit.assert_called() + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + + +def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1") + mock_db.session.query.return_value.where.return_value.first.return_value = existing_config + mock_validate = mocker.patch.object(service, "_custom_credentials_validate") + + # Act + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + # Assert + assert mock_validate.call_count == 2 + assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config + assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + + +def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + load_balancing_model_config = _load_balancing_model_config( + encrypted_config=json.dumps({"api_key": "old-encrypted-token"}) + ) + mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value") + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + load_balancing_model_config=load_balancing_model_config, + validate=False, + ) + + # Assert + assert result == {"api_key": "enc:old-plain-value", "region": "us"} + mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value") + + +def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema()) + load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") + mock_factory = MagicMock() + mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + load_balancing_model_config=load_balancing_model_config, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:validated"} + mock_factory.model_credentials_validate.assert_called_once() + mock_factory.provider_credentials_validate.assert_not_called() + mock_encrypt.assert_called_once_with("tenant-1", "validated") + + +def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + mock_factory = MagicMock() + mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:provider-validated"} + mock_factory.provider_credentials_validate.assert_called_once() + mock_factory.model_credentials_validate.assert_not_called() + + +def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise( + service: ModelLoadBalancingService, +) -> None: + # Arrange + model_schema = _build_model_credential_schema() + provider_schema = _build_provider_credential_schema() + provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema) + provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema) + provider_configuration_without_schema = _build_provider_configuration() + + # Act + schema_from_model = service._get_credential_schema(provider_configuration_with_model) + schema_from_provider = service._get_credential_schema(provider_configuration_with_provider) + + # Assert + assert schema_from_model is model_schema + assert schema_from_provider is provider_schema + with pytest.raises(ValueError, match="No credential schema found"): + service._get_credential_schema(provider_configuration_without_schema) + + +def test_clear_credentials_cache_should_delete_load_balancing_cache_entry( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_cache_instance = MagicMock() + mock_cache_cls = mocker.patch( + "services.model_load_balancing_service.ProviderCredentialsCache", + return_value=mock_cache_instance, + ) + + # Act + service._clear_credentials_cache("tenant-1", "cfg-1") + + # Assert + mock_cache_cls.assert_called_once() + assert mock_cache_cls.call_args.kwargs == { + "tenant_id": "tenant-1", + "identity_id": "cfg-1", + "cache_type": mocker.ANY, + } + assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL" + mock_cache_instance.delete.assert_called_once() diff --git a/api/tests/unit_tests/services/test_saved_message_service.py b/api/tests/unit_tests/services/test_saved_message_service.py deleted file mode 100644 index 87b946fe46..0000000000 --- a/api/tests/unit_tests/services/test_saved_message_service.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -Comprehensive unit tests for SavedMessageService. - -This test suite provides complete coverage of saved message operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Pagination (TestSavedMessageServicePagination) -Tests saved message listing and pagination: -- Pagination with valid user (Account and EndUser) -- Pagination without user raises ValueError -- Pagination with last_id parameter -- Empty results when no saved messages exist -- Integration with MessageService pagination - -### 2. Save Operations (TestSavedMessageServiceSave) -Tests saving messages: -- Save message for Account user -- Save message for EndUser -- Save without user (no-op) -- Prevent duplicate saves (idempotent) -- Message validation through MessageService - -### 3. Delete Operations (TestSavedMessageServiceDelete) -Tests deleting saved messages: -- Delete saved message for Account user -- Delete saved message for EndUser -- Delete without user (no-op) -- Delete non-existent saved message (no-op) -- Proper database cleanup - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked - for fast, isolated unit tests -- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**User Types:** -- Account: Workspace members (console users) -- EndUser: API users (end users) - -**Saved Messages:** -- Users can save messages for later reference -- Each user has their own saved message list -- Saving is idempotent (duplicate saves ignored) -- Deletion is safe (non-existent deletes ignored) -""" - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest - -from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models import Account -from models.model import App, EndUser, Message -from models.web import SavedMessage -from services.saved_message_service import SavedMessageService - - -class SavedMessageServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - saved message operations. - """ - - @staticmethod - def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock: - """ - Create a mock Account object. - - Args: - account_id: Unique identifier for the account - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Account object with specified attributes - """ - account = create_autospec(Account, instance=True) - account.id = account_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock: - """ - Create a mock EndUser object. - - Args: - user_id: Unique identifier for the end user - **kwargs: Additional attributes to set on the mock - - Returns: - Mock EndUser object with specified attributes - """ - user = create_autospec(EndUser, instance=True) - user.id = user_id - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant/workspace identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - """ - app = create_autospec(App, instance=True) - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - app.mode = kwargs.get("mode", "chat") - for key, value in kwargs.items(): - setattr(app, key, value) - return app - - @staticmethod - def create_message_mock( - message_id: str = "msg-123", - app_id: str = "app-123", - **kwargs, - ) -> Mock: - """ - Create a mock Message object. - - Args: - message_id: Unique identifier for the message - app_id: Associated app identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Message object with specified attributes - """ - message = create_autospec(Message, instance=True) - message.id = message_id - message.app_id = app_id - message.query = kwargs.get("query", "Test query") - message.answer = kwargs.get("answer", "Test answer") - message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(message, key, value) - return message - - @staticmethod - def create_saved_message_mock( - saved_message_id: str = "saved-123", - app_id: str = "app-123", - message_id: str = "msg-123", - created_by: str = "user-123", - created_by_role: str = "account", - **kwargs, - ) -> Mock: - """ - Create a mock SavedMessage object. - - Args: - saved_message_id: Unique identifier for the saved message - app_id: Associated app identifier - message_id: Associated message identifier - created_by: User who saved the message - created_by_role: Role of the user ('account' or 'end_user') - **kwargs: Additional attributes to set on the mock - - Returns: - Mock SavedMessage object with specified attributes - """ - saved_message = create_autospec(SavedMessage, instance=True) - saved_message.id = saved_message_id - saved_message.app_id = app_id - saved_message.message_id = message_id - saved_message.created_by = created_by - saved_message.created_by_role = created_by_role - saved_message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(saved_message, key, value) - return saved_message - - -@pytest.fixture -def factory(): - """Provide the test data factory to all tests.""" - return SavedMessageServiceTestDataFactory - - -class TestSavedMessageServicePagination: - """Test saved message pagination operations.""" - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Create saved messages for this user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="account", - ) - for i in range(3) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - mock_db_session.query.assert_called_once_with(SavedMessage) - # Verify MessageService was called with correct message IDs - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=["msg-0", "msg-1", "msg-2"], - ) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - - # Create saved messages for this end user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="end_user", - ) - for i in range(2) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10) - - # Assert - assert result == expected_pagination - # Verify correct role was used in query - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=10, - include_ids=["msg-0", "msg-1"], - ) - - def test_pagination_without_user_raises_error(self, factory): - """Test that pagination without user raises ValueError.""" - # Arrange - app = factory.create_app_mock() - - # Act & Assert - with pytest.raises(ValueError, match="User is required"): - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with last_id parameter.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - last_id = "msg-last" - - saved_messages = [ - factory.create_saved_message_mock( - message_id=f"msg-{i}", - app_id=app.id, - created_by=user.id, - ) - for i in range(5) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10) - - # Assert - assert result == expected_pagination - # Verify last_id was passed to MessageService - mock_message_pagination.assert_called_once() - call_args = mock_message_pagination.call_args - assert call_args.kwargs["last_id"] == last_id - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory): - """Test pagination when user has no saved messages.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Mock database query returning empty list - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - # Verify MessageService was called with empty include_ids - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=[], - ) - - -class TestSavedMessageServiceSave: - """Test save message operations.""" - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_account(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock(message_id="msg-123", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "account" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message = factory.create_message_mock(message_id="msg-456", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "end_user" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_without_user_does_nothing(self, mock_db_session, factory): - """Test that saving without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.save(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory): - """Test that saving an already saved message is idempotent.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-789" - - # Mock database query - existing saved message found - existing_saved = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_saved - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message_id) - - # Assert - no new saved message created - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - mock_get_message.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory): - """Test that save validates message exists through MessageService.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock() - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - MessageService.get_message was called for validation - mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id) - - -class TestSavedMessageServiceDelete: - """Test delete saved message operations.""" - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_account(self, mock_db_session, factory): - """Test deleting a saved message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-123" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_end_user(self, mock_db_session, factory): - """Test deleting a saved message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message_id = "msg-456" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="end_user", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_without_user_does_nothing(self, mock_db_session, factory): - """Test that deleting without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.delete(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory): - """Test that deleting a non-existent saved message is a no-op.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-nonexistent" - - # Mock database query - no saved message found - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - no deletion occurred - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory): - """Test that delete only removes the user's own saved message.""" - # Arrange - app = factory.create_app_mock() - user1 = factory.create_account_mock(account_id="user-1") - message_id = "msg-shared" - - # Mock database query - finds user1's saved message - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user1.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user1, message_id=message_id) - - # Assert - only user1's saved message is deleted - mock_db_session.delete.assert_called_once_with(saved_message) - # Verify the query filters by user - assert mock_query.where.called diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index be64e431ba..ef53df9350 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -26,7 +27,7 @@ class _SessionContext: return None -def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: +def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock: dataset = MagicMock(name="dataset") dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" @@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: if has_document: doc = MagicMock(name="document") doc.doc_language = "en" - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX segment.document = doc else: segment.document = None @@ -168,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: vector_cls = MagicMock() monkeypatch.setattr(summary_module, "Vector", vector_cls) - SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset) vector_cls.assert_not_called() @@ -620,16 +622,16 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _dataset(indexing_technique="economy") + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] dataset = _dataset() assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] @@ -637,7 +639,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg1 = _segment() seg2 = _segment() @@ -673,7 +675,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() query = MagicMock() @@ -696,7 +698,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg = _segment() session = MagicMock() @@ -777,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo def test_enable_summaries_for_segments_skips_non_high_quality() -> None: - SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY)) def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: @@ -931,11 +933,10 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon def test_update_summary_for_segment_skip_conditions() -> None: - assert ( - SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None - ) + economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None seg = _segment(has_document=True) - seg.document.doc_form = "qa_model" + seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 264eac4d77..b09463b1bc 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -75,6 +75,7 @@ import pytest from werkzeug.exceptions import NotFound from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TagServiceTestDataFactory: def create_tag_mock( tag_id: str = "tag-123", name: str = "Test Tag", - tag_type: str = "app", + tag_type: TagType = TagType.APP, tenant_id: str = "tenant-123", **kwargs, ) -> Mock: @@ -315,7 +316,7 @@ class TestTagServiceRetrieval: - get_tags_by_target_id: Get all tags bound to a specific target """ - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_binding_counts(self, mock_db_session, factory): """ Test retrieving tags with their binding counts. @@ -372,7 +373,7 @@ class TestTagServiceRetrieval: # Verify database query was called mock_db_session.query.assert_called_once() - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_keyword_filter(self, mock_db_session, factory): """ Test retrieving tags filtered by keyword (case-insensitive). @@ -426,7 +427,7 @@ class TestTagServiceRetrieval: # 2. Additional WHERE clause for keyword filtering assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): """ Test retrieving target IDs by tag IDs. @@ -482,7 +483,7 @@ class TestTagServiceRetrieval: # Verify both queries were executed assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): """ Test that empty tag_ids returns empty list. @@ -510,7 +511,7 @@ class TestTagServiceRetrieval: assert results == [], "Should return empty list for empty input" mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name(self, mock_db_session, factory): """ Test retrieving tags by name. @@ -552,7 +553,7 @@ class TestTagServiceRetrieval: assert len(results) == 1, "Should find exactly one tag" assert results[0].name == tag_name, "Tag name should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): """ Test that missing tag_type or tag_name returns empty list. @@ -580,7 +581,7 @@ class TestTagServiceRetrieval: # Verify no database queries were executed mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_by_target_id(self, mock_db_session, factory): """ Test retrieving tags associated with a specific target. @@ -653,7 +654,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") @patch("services.tag_service.uuid.uuid4", autospec=True) def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ @@ -705,7 +706,7 @@ class TestTagServiceCRUD: # Verify tag attributes added_tag = mock_db_session.add.call_args[0][0] assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == "app", "Tag type should match" + assert added_tag.type == TagType.APP, "Tag type should match" assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" @@ -742,7 +743,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ Test updating a tag name. @@ -794,7 +795,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_error_for_duplicate_name( self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory ): @@ -826,7 +827,7 @@ class TestTagServiceCRUD: with pytest.raises(ValueError, match="Tag name already exists"): TagService.update_tags(args, tag_id="tag-123") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): """ Test that updating a non-existent tag raises NotFound. @@ -858,7 +859,7 @@ class TestTagServiceCRUD: with pytest.raises(NotFound, match="Tag not found"): TagService.update_tags(args, tag_id="nonexistent") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_binding_count(self, mock_db_session, factory): """ Test getting the count of bindings for a tag. @@ -894,7 +895,7 @@ class TestTagServiceCRUD: # Verify count matches expectation assert result == expected_count, "Binding count should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag(self, mock_db_session, factory): """ Test deleting a tag and its bindings. @@ -950,7 +951,7 @@ class TestTagServiceCRUD: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_raises_not_found(self, mock_db_session, factory): """ Test that deleting a non-existent tag raises NotFound. @@ -998,7 +999,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test creating tag bindings. @@ -1049,7 +1050,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test that saving duplicate bindings is idempotent. @@ -1089,7 +1090,7 @@ class TestTagServiceBindings: mock_db_session.add.assert_not_called(), "Should not create duplicate binding" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): """ Test deleting a tag binding. @@ -1137,7 +1138,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_called_once(), "Should commit transaction" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): """ Test that deleting a non-existent binding is a no-op. @@ -1174,7 +1175,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): """ Test validating that a dataset target exists. @@ -1215,7 +1216,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for dataset" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): """ Test validating that an app target exists. @@ -1256,7 +1257,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for app" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_dataset( self, mock_db_session, mock_current_user, factory ): @@ -1288,7 +1289,7 @@ class TestTagServiceBindings: TagService.check_target_exists("knowledge", "nonexistent") @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): """ Test that missing app raises NotFound. diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py new file mode 100644 index 0000000000..81a3b181fd --- /dev/null +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -0,0 +1,1249 @@ +from __future__ import annotations + +import contextlib +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from core.plugin.entities.plugin_daemon import CredentialType +from models.provider_ids import TriggerProviderID +from services.trigger.trigger_provider_service import TriggerProviderService + + +def _patch_redis_lock(mocker: MockerFixture) -> None: + mock_redis = mocker.patch("services.trigger.trigger_provider_service.redis_client") + mock_redis.lock.return_value = contextlib.nullcontext() + + +def _mock_get_trigger_provider(mocker: MockerFixture, provider: object | None) -> None: + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.get_trigger_provider", + return_value=provider, + ) + + +def _encrypter_mock( + *, + decrypted: dict | None = None, + encrypted: dict | None = None, + masked: dict | None = None, +) -> MagicMock: + enc = MagicMock() + enc.decrypt.return_value = decrypted or {} + enc.encrypt.return_value = encrypted or {} + enc.mask_credentials.return_value = masked or {} + enc.mask_plugin_credentials.return_value = masked or {} + return enc + + +@pytest.fixture +def provider_id() -> TriggerProviderID: + # Arrange + return TriggerProviderID("langgenius/github/github") + + +@pytest.fixture(autouse=True) +def mock_db_engine(mocker: MockerFixture) -> SimpleNamespace: + # Arrange + mocked_db = SimpleNamespace(engine=object()) + mocker.patch("services.trigger.trigger_provider_service.db", mocked_db) + return mocked_db + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mocks the database session context manager used by TriggerProviderService.""" + # Arrange + mock_session_instance = MagicMock() + mock_session_cm = MagicMock() + mock_session_cm.__enter__.return_value = mock_session_instance + mock_session_cm.__exit__.return_value = False + mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + return mock_session_instance + + +@pytest.fixture +def provider_controller() -> MagicMock: + # Arrange + controller = MagicMock() + controller.get_credential_schema_config.return_value = [] + controller.get_properties_schema.return_value = [] + controller.get_oauth_client_schema.return_value = [] + controller.plugin_unique_identifier = "langgenius/github:0.0.1" + return controller + + +def test_get_trigger_provider_should_return_api_entity_from_manager( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + provider = MagicMock() + provider.to_api_entity.return_value = {"provider": "ok"} + _mock_get_trigger_provider(mocker, provider) + + # Act + result = TriggerProviderService.get_trigger_provider("tenant-1", provider_id) + + # Assert + assert result == {"provider": "ok"} + + +def test_list_trigger_providers_should_return_api_entities_from_manager(mocker: MockerFixture) -> None: + # Arrange + provider_a = MagicMock() + provider_b = MagicMock() + provider_a.to_api_entity.return_value = {"id": "a"} + provider_b.to_api_entity.return_value = {"id": "b"} + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.list_all_trigger_providers", + return_value=[provider_a, provider_b], + ) + + # Act + result = TriggerProviderService.list_trigger_providers("tenant-1") + + # Assert + assert result == [{"id": "a"}, {"id": "b"}] + + +def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_subscriptions( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_session.query.return_value = query + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert result == [] + + +def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workflow_counts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + api_sub = SimpleNamespace( + id="sub-1", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + parameters={"event": "push"}, + workflows_in_use=0, + ) + db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) + usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) + + query_subs = MagicMock() + query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] + query_usage = MagicMock() + query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] + mock_session.query.side_effect = [query_subs, query_usage] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) + prop_enc = _encrypter_mock(decrypted={"hook": "plain"}, masked={"hook": "****"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert len(result) == 1 + assert result[0].credentials == {"token": "****"} + assert result[0].properties == {"hook": "****"} + assert result[0].workflows_in_use == 2 + + +def test_add_trigger_subscription_should_create_subscription_successfully_for_api_key( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) + prop_enc = _encrypter_mock(encrypted={"project": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(cred_enc, MagicMock()), (prop_enc, MagicMock())], + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={"event": "push"}, + properties={"project": "demo"}, + credentials={"api_key": "plain"}, + ) + + # Assert + assert result["result"] == "success" + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(encrypted={"p": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.UNAUTHORIZED, + parameters={}, + properties={"p": "v"}, + credentials={}, + subscription_id="sub-fixed", + ) + + # Assert + assert result == {"result": "success", "id": "sub-fixed"} + + +def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ + mock_session.query.return_value = query_count + _mock_get_trigger_provider(mocker, provider_controller) + mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="Maximum number of providers"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + mock_logger.exception.assert_called_once() + + +def test_add_trigger_subscription_should_raise_error_when_name_exists( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_count, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="Credential name 'main' already exists"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + + +def test_update_trigger_subscription_should_raise_error_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query_sub + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1") + + +def test_update_trigger_subscription_should_raise_error_when_name_conflicts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + provider_id="langgenius/github/github", + credential_type=CredentialType.API_KEY.value, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_sub, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1", name="new-name") + + +def test_update_trigger_subscription_should_update_fields_and_clear_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + properties={"project": "enc-old"}, + parameters={"event": "old"}, + credentials={"api_key": "enc-old"}, + credential_type=CredentialType.API_KEY.value, + credential_expires_at=0, + expires_at=0, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_sub, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) + cred_enc = _encrypter_mock(encrypted={"api_key": "new-key"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(prop_enc, MagicMock()), (cred_enc, MagicMock())], + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.update_trigger_subscription( + tenant_id="tenant-1", + subscription_id="sub-1", + name="new", + properties={"project": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "new"}, + credentials={"api_key": "plain-key"}, + credential_expires_at=100, + expires_at=200, + ) + + # Assert + assert subscription.name == "new" + assert subscription.parameters == {"event": "new"} + assert subscription.credentials == {"api_key": "new-key"} + assert subscription.credential_expires_at == 100 + assert subscription.expires_at == 200 + mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() + + +def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is None + + +def test_get_subscription_by_id_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"project": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + prop_enc = _encrypter_mock(decrypted={"project": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"project": "plain"} + + +def test_delete_trigger_provider_should_raise_error_when_subscription_missing( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + +def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.OAUTH2.value, + credentials={"token": "enc"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + side_effect=RuntimeError("remote fail"), + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + # Assert + mock_session.delete.assert_called_once_with(subscription) + mock_delete_cache.assert_called_once() + + +def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-2", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.UNAUTHORIZED.value, + credentials={}, + to_entity=lambda: SimpleNamespace(id="sub-2"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={}), MagicMock()), + ) + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-2") + + # Assert + mock_unsubscribe.assert_not_called() + mock_session.delete.assert_called_once_with(subscription) + + +def test_refresh_oauth_token_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + Assert + with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + provider_id=str(provider_id), + user_id="user-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"access_token": "enc"}, + credential_expires_at=0, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(cred_enc, cache), + ) + mocker.patch.object(TriggerProviderService, "get_oauth_client", return_value={"client_id": "id"}) + refreshed = SimpleNamespace(credentials={"access_token": "new"}, expires_at=12345) + oauth_handler = MagicMock() + oauth_handler.refresh_credentials.return_value = refreshed + mocker.patch("services.trigger.trigger_provider_service.OAuthHandler", return_value=oauth_handler) + + # Act + result = TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + # Assert + assert result == {"result": "success", "expires_at": 12345} + assert subscription.credentials == {"access_token": "new"} + assert subscription.credential_expires_at == 12345 + mock_session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_refresh_subscription_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + +def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + subscription = SimpleNamespace(expires_at=200) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "skipped", "expires_at": 200} + + +def test_refresh_subscription_should_refresh_and_persist_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + endpoint_id="endpoint-1", + expires_at=50, + provider_id=str(provider_id), + parameters={"event": "push"}, + properties={"p": "enc"}, + credentials={"c": "enc"}, + credential_type=CredentialType.API_KEY.value, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"c": "plain"}) + prop_cache = MagicMock() + prop_enc = _encrypter_mock(decrypted={"p": "plain"}, encrypted={"p": "new-enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, prop_cache), + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + provider_controller.refresh_trigger.return_value = SimpleNamespace(properties={"p": "new"}, expires_at=999) + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "success", "expires_at": 999} + assert subscription.properties == {"p": "new-enc"} + assert subscription.expires_at == 999 + mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() + + +def test_get_oauth_client_should_return_tenant_client_when_available( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + system_client = None + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = tenant_client + mock_session.query.return_value = query_tenant + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "plain"} + + +def test_get_oauth_client_should_return_none_when_plugin_not_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result is None + + +def test_get_oauth_client_should_return_decrypted_system_client_when_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + return_value={"client_id": "system"}, + ) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "system"} + + +def test_get_oauth_client_should_raise_error_when_system_decryption_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + side_effect=RuntimeError("bad data"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Error decrypting system oauth params"): + TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + +def test_is_oauth_system_client_exists_should_return_false_when_unverified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is False + + +@pytest.mark.parametrize("has_client", [True, False]) +def test_is_oauth_system_client_exists_should_reflect_database_record( + has_client: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is has_client + + +def test_save_custom_oauth_client_params_should_return_success_when_nothing_to_update( + provider_id: TriggerProviderID, +) -> None: + # Arrange + # Act + result = TriggerProviderService.save_custom_oauth_client_params("tenant-1", provider_id, None, None) + + # Assert + assert result == {"result": "success"} + + +def test_save_custom_oauth_client_params_should_create_record_and_clear_params_when_client_params_none( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query + _mock_get_trigger_provider(mocker, provider_controller) + fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params=None, + enabled=True, + ) + + # Assert + assert result == {"result": "success"} + assert fake_model.encrypted_oauth_params == "{}" + assert fake_model.enabled is True + mock_session.add.assert_called_once_with(fake_model) + mock_session.commit.assert_called_once() + + +def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(enc, cache), + ) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params={"client_id": HIDDEN_VALUE, "client_secret": "new"}, + enabled=None, + ) + + # Assert + assert result == {"result": "success"} + assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} + cache.delete.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {} + + +def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "pl***id"} + + +def test_delete_custom_oauth_client_params_should_delete_record_and_commit( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 + + # Act + result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"result": "success"} + mock_session.commit.assert_called_once() + + +@pytest.mark.parametrize("exists", [True, False]) +def test_is_oauth_custom_client_enabled_should_return_expected_boolean( + exists: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + + # Act + result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) + + # Assert + assert result is exists + + +def test_get_subscription_by_endpoint_should_return_none_when_not_found( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is None + + +def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={"token": "plain"}), MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(_encrypter_mock(decrypted={"hook": "plain"}), MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"hook": "plain"} + + +def test_verify_subscription_credentials_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_api_key_validation_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + provider_controller.validate_credentials.side_effect = RuntimeError("bad credentials") + + # Act + Assert + with pytest.raises(ValueError, match="Invalid credentials: bad credentials"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + +def test_verify_subscription_credentials_should_return_verified_when_api_key_validation_succeeds( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + # Assert + assert result == {"verified": True} + + +def test_verify_subscription_credentials_should_return_verified_for_non_api_key_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.OAUTH2.value, credentials={}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + # Assert + assert result == {"verified": True} + + +def test_rebuild_trigger_subscription_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_for_unsupported_credential_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.UNAUTHORIZED.value) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + Assert + with pytest.raises(ValueError, match="not supported for auto creation"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=False, message="remote error"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to delete previous subscription"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_resubscribe_and_update_existing_subscription( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old-key"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + new_subscription = SimpleNamespace(properties={"project": "new"}, expires_at=888) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=True, message="ok"), + ) + mock_subscribe = mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.subscribe_trigger", + return_value=new_subscription, + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + mock_update = mocker.patch.object(TriggerProviderService, "update_trigger_subscription") + + # Act + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "push"}, + name="updated", + ) + + # Assert + call_kwargs = mock_subscribe.call_args.kwargs + assert call_kwargs["credentials"]["api_key"] == "old-key" + assert call_kwargs["credentials"]["region"] == "us" + mock_update.assert_called_once_with( + tenant_id="tenant-1", + subscription_id="sub-1", + name="updated", + parameters={"event": "push"}, + credentials={"api_key": "old-key", "region": "us"}, + properties={"project": "new"}, + expires_at=888, + ) diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 7b0103a2a1..16d3011810 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from services.vector_service import VectorService @@ -31,8 +32,8 @@ class _ParentDocStub: def _make_dataset( *, - indexing_technique: str = "high_quality", - doc_form: str = "text_model", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", is_multimodal: bool = False, @@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_called_once() args, kwargs = index_processor.load.call_args @@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) assert index_processor.load.call_count == 2 first_args, first_kwargs = index_processor.load.call_args_list[0] @@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_not_called() @@ -191,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider="openai", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -240,7 +241,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -328,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) segment = _make_segment() dataset_document = MagicMock() @@ -347,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = _make_segment() vector_instance = MagicMock() @@ -363,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -379,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -392,7 +393,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1") segment = _make_segment(segment_id="seg-1") dataset_document = MagicMock() @@ -439,7 +440,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX) segment = _make_segment() dataset_document = MagicMock() dataset_document.doc_language = "en" @@ -472,7 +473,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = MagicMock() child_chunk.content = "child" child_chunk.index_node_id = "id" @@ -488,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) @@ -504,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = MagicMock() new_chunk.content = "n" @@ -535,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) VectorService.update_child_chunk_vector([], [], [], dataset) @@ -560,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) vector_cls = MagicMock() @@ -574,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) vector_cls = MagicMock() @@ -590,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) vector_instance = MagicMock(name="vector_instance") @@ -611,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -629,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -662,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -682,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py index e2775ce90c..e973da7d56 100644 --- a/api/tests/unit_tests/services/test_website_service.py +++ b/api/tests/unit_tests/services/test_website_service.py @@ -443,7 +443,7 @@ def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monk def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None: firecrawl_instance = MagicMock() - firecrawl_instance.check_crawl_status.return_value = {"status": "completed"} + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 1, "current": 1, "data": []} monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) redis_mock = MagicMock() diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 57c0464dc6..d26c2f674f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -10,18 +10,36 @@ This test suite covers: """ import json +import uuid +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest -from dify_graph.enums import BuiltinNodeTypes +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + ErrorStrategy, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.variables.input_entities import VariableEntityType from libs.datetime_utils import naive_utc_now +from models.human_input import RecipientType from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from services.workflow_service import WorkflowService +from services.workflow_service import ( + WorkflowService, + _rebuild_file_for_user_inputs_in_start_node, + _rebuild_single_file, + _setup_variable_pool, +) class TestWorkflowAssociatedDataFactory: @@ -544,6 +562,89 @@ class TestWorkflowService: conversation_variables=[], ) + def test_restore_published_workflow_to_draft_keeps_source_features_unmodified( + self, workflow_service, mock_db_session + ): + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + normalized_features = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + source_workflow = Workflow( + id="published-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version="2026-03-19T00:00:00", + graph=json.dumps(TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()), + features=json.dumps(legacy_features), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = Workflow( + id="draft-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + with ( + patch.object(workflow_service, "get_published_workflow_by_id", return_value=source_workflow), + patch.object(workflow_service, "get_draft_workflow", return_value=draft_workflow), + patch.object(workflow_service, "validate_graph_structure"), + patch.object(workflow_service, "validate_features_structure") as mock_validate_features, + patch("services.workflow_service.app_draft_workflow_was_synced"), + ): + result = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=source_workflow.id, + account=account, + ) + + mock_validate_features.assert_called_once_with(app_model=app, features=normalized_features) + assert result is draft_workflow + assert source_workflow.serialized_features == json.dumps(legacy_features) + assert draft_workflow.serialized_features == json.dumps(legacy_features) + mock_db_session.session.commit.assert_called_once() + # ==================== Workflow Validation Tests ==================== # These tests verify graph structure and feature configuration validation @@ -1226,3 +1327,1416 @@ class TestWorkflowService: with pytest.raises(ValueError, match="not supported convert to workflow"): workflow_service.convert_to_workflow(app, account, args) + + +# =========================================================================== +# TestWorkflowServiceCredentialValidation +# Tests for _validate_workflow_credentials and related private helpers +# =========================================================================== + + +class TestWorkflowServiceCredentialValidation: + """ + Tests for the private credential-validation helpers on WorkflowService. + + These helpers gate `publish_workflow` when `PluginManager` is enabled. + Each test focuses on a distinct branch inside `_validate_workflow_credentials`, + `_validate_llm_model_config`, `_check_default_tool_credential`, and the + load-balancing path. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + @staticmethod + def _make_workflow(nodes: list[dict]) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.tenant_id = "tenant-1" + wf.app_id = "app-1" + wf.graph_dict = {"nodes": nodes} + return wf + + # --- _validate_workflow_credentials: tool node (with credential_id) --- + + def test_validate_workflow_credentials_should_check_tool_credential_when_credential_id_present( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + "credential_id": "cred-123", + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + # Should not raise; mock allows the call + service._validate_workflow_credentials(workflow) + mock_check.assert_called_once() + + def test_validate_workflow_credentials_should_check_default_credential_when_no_credential_id( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + # No credential_id — should fall back to default + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + + # Assert + mock_default.assert_called_once_with("tenant-1", "my-provider") + + def test_validate_workflow_credentials_should_skip_tool_node_without_provider( + self, service: WorkflowService + ) -> None: + """Tool nodes without a provider_id should be silently skipped.""" + # Arrange + nodes = [{"id": "tool-node", "data": {"type": "tool"}}] + workflow = self._make_workflow(nodes) + + # Act + Assert (no error raised) + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + mock_default.assert_not_called() + + def test_validate_workflow_credentials_should_validate_llm_node_with_model_config( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_raise_for_llm_node_missing_model( + self, service: WorkflowService + ) -> None: + """LLM nodes without provider AND name should raise ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": {"type": "llm", "model": {"provider": "openai"}}, # name missing + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with pytest.raises(ValueError, match="Missing provider or model configuration"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_wrap_unexpected_exception_in_value_error( + self, service: WorkflowService + ) -> None: + """Non-ValueError exceptions from validation must be re-raised as ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch.object(service, "_validate_llm_model_config", side_effect=RuntimeError("boom")): + with pytest.raises(ValueError, match="boom"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_validate_agent_node_model(self, service: WorkflowService) -> None: + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {"provider": "openai", "model": "gpt-4"}}, + "tools": {"value": []}, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_validate_agent_tools(self, service: WorkflowService) -> None: + """Each agent tool with a provider should be checked for credential compliance.""" + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {}}, # no model config + "tools": { + "value": [ + {"provider_name": "provider-a", "credential_id": "cred-a"}, + {"provider_name": "provider-b"}, # uses default + ] + }, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check, + patch.object(service, "_check_default_tool_credential") as mock_default, + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_check.assert_called_once() # provider-a has credential_id + mock_default.assert_called_once_with("tenant-1", "provider-b") + + # --- _validate_llm_model_config --- + + def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: + """If ModelManager raises any exception it must be wrapped into ValueError.""" + # Arrange + with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + def test_validate_llm_model_config_success(self, service: WorkflowService) -> None: + """Test success path with ProviderManager and Model entities.""" + mock_model = MagicMock() + mock_model.model = "gpt-4" + mock_model.provider.provider = "openai" + + mock_configs = MagicMock() + mock_configs.get_models.return_value = [mock_model] + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # Assert + mock_model.raise_for_status.assert_called_once() + + def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: + """Test ValueError when model is not found in provider configurations.""" + mock_configs = MagicMock() + mock_configs.get_models.return_value = [] # No models + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + Assert + with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # --- _check_default_tool_credential --- + + def test_check_default_tool_credential_should_silently_pass_when_no_provider_found( + self, service: WorkflowService + ) -> None: + """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" + # Arrange + with patch("services.workflow_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Act + Assert (should NOT raise) + service._check_default_tool_credential("tenant-1", "some-provider") + + def test_check_default_tool_credential_should_raise_when_compliance_fails(self, service: WorkflowService) -> None: + # Arrange + mock_provider = MagicMock() + mock_provider.id = "builtin-cred-id" + with ( + patch("services.workflow_service.db") as mock_db, + patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_provider + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate default credential"): + service._check_default_tool_credential("tenant-1", "some-provider") + + # --- _is_load_balancing_enabled --- + + def test_is_load_balancing_enabled_should_return_false_when_provider_not_found( + self, service: WorkflowService + ) -> None: + # Arrange + with patch("services.workflow_service.db"): + service_instance = WorkflowService() + + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_configs = MagicMock() + mock_configs.get.return_value = None # provider not found + mock_get_configs.return_value = mock_configs + + # Act + result = service_instance._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + def test_is_load_balancing_enabled_should_return_true_when_setting_enabled(self, service: WorkflowService) -> None: + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_provider_config = MagicMock() + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = True + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + mock_configs = MagicMock() + mock_configs.get.return_value = mock_provider_config + mock_get_configs.return_value = mock_configs + + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is True + + def test_is_load_balancing_enabled_should_return_false_on_exception(self, service: WorkflowService) -> None: + """Any exception should be swallowed and return False.""" + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations", side_effect=RuntimeError("db down")): + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + # --- _get_load_balancing_configs --- + + def test_get_load_balancing_configs_should_return_empty_list_on_exception(self, service: WorkflowService) -> None: + """Any exception during LB config retrieval should return an empty list.""" + # Arrange + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=RuntimeError("fail"), + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert + assert result == [] + + def test_get_load_balancing_configs_should_merge_predefined_and_custom(self, service: WorkflowService) -> None: + # Arrange + predefined = [{"credential_id": "cred-a"}, {"credential_id": None}] + custom = [{"credential_id": "cred-b"}] + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=[ + (None, predefined), # first call: predefined-model + (None, custom), # second call: custom-model + ], + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert — only entries with a credential_id should be returned + assert len(result) == 2 + assert all(c["credential_id"] for c in result) + + # --- _validate_load_balancing_credentials --- + + def test_validate_load_balancing_credentials_should_skip_when_no_model_config( + self, service: WorkflowService + ) -> None: + """Missing provider or model in node_data should be a no-op.""" + # Arrange + workflow = self._make_workflow([]) + node_data: dict = {} # no model key + + # Act + Assert (no error expected) + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_skip_when_lb_not_enabled( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + + # Act + Assert (no error expected) + with patch.object(service, "_is_load_balancing_enabled", return_value=False): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_raise_when_compliance_fails( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + lb_configs = [{"credential_id": "cred-lb-1"}] + + # Act + Assert + with ( + patch.object(service, "_is_load_balancing_enabled", return_value=True), + patch.object(service, "_get_load_balancing_configs", return_value=lb_configs), + patch( + "core.helper.credential_utils.check_credential_policy_compliance", + side_effect=Exception("policy violation"), + ), + ): + with pytest.raises(ValueError, match="Invalid load balancing credentials"): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + +# =========================================================================== +# TestWorkflowServiceExecutionHelpers +# Tests for _apply_error_strategy, _populate_execution_result, _execute_node_safely +# =========================================================================== + + +class TestWorkflowServiceExecutionHelpers: + """ + Tests for the private execution-result handling methods: + _apply_error_strategy, _populate_execution_result, _execute_node_safely. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + # --- _apply_error_strategy --- + + def test_apply_error_strategy_should_return_exception_status_noderunresult(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="something went wrong", + error_type="SomeError", + inputs={"x": 1}, + outputs={}, + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + assert result.error == "something went wrong" + assert result.metadata[WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY] == ErrorStrategy.FAIL_BRANCH + + def test_apply_error_strategy_should_include_default_values_for_default_value_strategy( + self, service: WorkflowService + ) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.DEFAULT_VALUE + node.default_value_dict = {"output_key": "fallback"} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="err", + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.outputs.get("output_key") == "fallback" + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + + # --- _populate_execution_result --- + + def test_populate_execution_result_should_set_succeeded_fields_when_run_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"q": "hello"}, + process_data={"steps": 3}, + outputs={"answer": "hi"}, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10}, + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert node_execution.outputs == {"answer": "hi"} + assert node_execution.error is None # SUCCEEDED status doesn't set error + + def test_populate_execution_result_should_set_failed_status_and_error_when_not_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + + # Act + service._populate_execution_result(node_execution, None, False, "catastrophic failure") + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_execution.error == "catastrophic failure" + + def test_populate_execution_result_should_set_error_field_for_exception_status( + self, service: WorkflowService + ) -> None: + """A succeeded=True result with EXCEPTION status should still populate the error field.""" + # Arrange + node_execution = MagicMock() + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + error="constraint violated", + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.EXCEPTION + assert node_execution.error == "constraint violated" + + # --- _execute_node_safely --- + + def test_execute_node_safely_should_return_succeeded_result_on_happy_path(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_run_result.error = None + + succeeded_event = MagicMock(spec=NodeRunSucceededEvent) + succeeded_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield succeeded_event + + return node, _gen() + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert run_succeeded is True + assert error is None + + def test_execute_node_safely_should_return_failed_result_on_failed_event(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.FAILED + node_run_result.error = "node exploded" + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, _, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert run_succeeded is False + assert error == "node exploded" + + def test_execute_node_safely_should_handle_workflow_node_run_failed_error(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + exc = WorkflowNodeRunFailedError(node, "runtime failure") + + def invoke_fn(): + raise exc + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert out_result is None + assert run_succeeded is False + assert error == "runtime failure" + + def test_execute_node_safely_should_raise_when_no_result_event(self, service: WorkflowService) -> None: + """If the generator produces no NodeRunSucceededEvent/NodeRunFailedEvent, ValueError is expected.""" + # Arrange + node = MagicMock() + node.error_strategy = None + + def invoke_fn(): + def _gen(): + yield from [] + + return node, _gen() + + # Act + Assert + with pytest.raises(ValueError, match="no result returned"): + service._execute_node_safely(invoke_fn) + + # --- _apply_error_strategy with FAIL_BRANCH strategy --- + + def test_execute_node_safely_should_apply_error_strategy_on_failed_status(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + + original_result = MagicMock() + original_result.status = WorkflowNodeExecutionStatus.FAILED + original_result.error = "oops" + original_result.error_type = "ValueError" + original_result.inputs = {} + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = original_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, result, run_succeeded, _ = service._execute_node_safely(invoke_fn) + + # Assert — after applying error strategy status becomes EXCEPTION + assert result is not None + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + # run_succeeded should be True because EXCEPTION is in the succeeded set + assert run_succeeded is True + + +# =========================================================================== +# TestWorkflowServiceGetNodeLastRun +# Tests for get_node_last_run delegation to repository +# =========================================================================== + + +class TestWorkflowServiceGetNodeLastRun: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_node_last_run_should_delegate_to_repository(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "tenant-1" + app.id = "app-1" + workflow = MagicMock(spec=Workflow) + workflow.id = "wf-1" + expected = MagicMock() + + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = expected + + # Act + result = service.get_node_last_run(app, workflow, "node-42") + + # Assert + assert result is expected + service._node_execution_service_repo.get_node_last_execution.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="wf-1", + node_id="node-42", + ) + + def test_get_node_last_run_should_return_none_when_repository_returns_none(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "t" + app.id = "a" + workflow = MagicMock(spec=Workflow) + workflow.id = "w" + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = None + + # Act + result = service.get_node_last_run(app, workflow, "node-x") + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceModuleLevelHelpers +# Tests for module-level helper functions exported from workflow_service +# =========================================================================== + + +class TestSetupVariablePool: + """ + Tests for the module-level `_setup_variable_pool` function. + This helper initialises the VariablePool used for single-step workflow execution. + """ + + def _make_workflow(self, workflow_type: str = WorkflowType.WORKFLOW.value) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.app_id = "app-1" + wf.id = "wf-1" + wf.type = workflow_type + wf.environment_variables = [] + return wf + + def test_setup_variable_pool_should_use_full_system_variables_for_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="hello", + files=[], + user_id="u-1", + user_inputs={"k": "v"}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — VariablePool should be called with a SystemVariable (non-default) + MockPool.assert_called_once() + call_kwargs = MockPool.call_args.kwargs + assert call_kwargs["user_inputs"] == {"k": "v"} + + def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.SystemVariable.default") as mock_default, + ): + _setup_variable_pool( + query="", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.LLM, # not a start/trigger node + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — SystemVariable.default() should be used for non-start nodes + mock_default.assert_called_once() + MockPool.assert_called_once() + + def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( + self, + ) -> None: + """For ADVANCED_CHAT workflows on a START node, query/conversation_id/dialogue_count should be set.""" + from models.workflow import WorkflowType + + # Arrange + workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="what is AI?", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-abc", + conversation_variables=[], + ) + + # Assert — we just verify VariablePool was called (chatflow path executed) + MockPool.assert_called_once() + + +class TestRebuildSingleFile: + """ + Tests for the module-level `_rebuild_single_file` function. + Ensures correct delegation to `build_from_mapping` / `build_from_mappings`. + """ + + def test_rebuild_single_file_should_call_build_from_mapping_for_file_type( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = {"url": "https://example.com/file.pdf", "type": "document"} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE) + + # Assert + assert result is mock_file + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_value_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for file object"): + _rebuild_single_file("tenant-1", "not-a-dict", VariableEntityType.FILE) + + def test_rebuild_single_file_should_call_build_from_mappings_for_file_list( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = [{"url": "https://example.com/a.pdf"}, {"url": "https://example.com/b.pdf"}] + mock_files = [MagicMock(), MagicMock()] + + # Act + with patch("services.workflow_service.build_from_mappings", return_value=mock_files) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE_LIST) + + # Assert + assert result is mock_files + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_list_value_not_list( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected list for file list object"): + _rebuild_single_file("tenant-1", "not-a-list", VariableEntityType.FILE_LIST) + + def test_rebuild_single_file_should_return_empty_list_for_empty_file_list( + self, + ) -> None: + # Arrange + Act + result = _rebuild_single_file("tenant-1", [], VariableEntityType.FILE_LIST) + + # Assert + assert result == [] + + def test_rebuild_single_file_should_raise_when_first_element_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for first element"): + _rebuild_single_file("tenant-1", ["not-a-dict"], VariableEntityType.FILE_LIST) + + +class TestRebuildFileForUserInputsInStartNode: + """ + Tests for the module-level `_rebuild_file_for_user_inputs_in_start_node` function. + """ + + def _make_start_node_data(self, variables: list) -> MagicMock: + start_data = MagicMock() + start_data.variables = variables + return start_data + + def _make_variable(self, name: str, var_type: VariableEntityType) -> MagicMock: + var = MagicMock() + var.variable = name + var.type = var_type + return var + + def test_rebuild_should_pass_through_non_file_variables( + self, + ) -> None: + # Arrange + text_var = self._make_variable("query", VariableEntityType.TEXT_INPUT) + start_data = self._make_start_node_data([text_var]) + user_inputs = {"query": "hello world"} + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — non-file inputs are untouched + assert result["query"] == "hello world" + + def test_rebuild_should_rebuild_file_variable( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + file_value = {"url": "https://example.com/file.pdf"} + user_inputs = {"attachment": file_value} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file): + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — the dict value should be replaced by the rebuilt File object + assert result["attachment"] is mock_file + + def test_rebuild_should_skip_variable_not_in_inputs( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + user_inputs: dict = {} # attachment not provided + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — no key should be added for missing inputs + assert "attachment" not in result + + +class TestWorkflowServiceResolveDeliveryMethod: + """ + Tests for the static helper `_resolve_human_input_delivery_method`. + """ + + def _make_method(self, method_id) -> MagicMock: + m = MagicMock() + m.id = method_id + return m + + def test_resolve_delivery_method_should_return_method_when_id_matches(self) -> None: + # Arrange + method_a = self._make_method("method-1") + method_b = self._make_method("method-2") + node_data = MagicMock() + node_data.delivery_methods = [method_a, method_b] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-2" + ) + + # Assert + assert result is method_b + + def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: + # Arrange + method_a = self._make_method("method-1") + node_data = MagicMock() + node_data.delivery_methods = [method_a] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="does-not-exist" + ) + + # Assert + assert result is None + + def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: + # Arrange + node_data = MagicMock() + node_data.delivery_methods = [] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-1" + ) + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceDraftExecution +# Tests for run_draft_workflow_node +# =========================================================================== + + +class TestWorkflowServiceDraftExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_draft_workflow_node_should_execute_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.id = "app-1" + app.tenant_id = "tenant-1" + account = MagicMock() + account.id = "user-1" + + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.id = "wf-1" + draft_workflow.tenant_id = "tenant-1" + draft_workflow.app_id = "app-1" + draft_workflow.graph_dict = {"nodes": []} + + node_id = "start-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.START)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + # Mocking complex dependencies + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.StartNodeData") as mock_start_data, + patch( + "services.workflow_service._rebuild_file_for_user_inputs_in_start_node", + side_effect=lambda **kwargs: kwargs["user_inputs"], + ), + patch("services.workflow_service._setup_variable_pool"), + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory") as mock_repo_factory, + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.START + mock_node.title = "Start Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="start-node", + node_type=BuiltinNodeTypes.START, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + mock_repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = mock_repo + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "start" + mock_execution_record.node_id = "start-node" + mock_execution_record.load_full_outputs.return_value = {} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + result = service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={"key": "val"}, + query="hi", + files=[], + ) + + # Assert + assert result is not None + mock_run.assert_called_once() + mock_repo.save.assert_called_once() + mock_saver_cls.return_value.save.assert_called_once() + + def test_run_draft_workflow_node_should_execute_non_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + account = MagicMock() + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.graph_dict = {"nodes": []} + node_id = "llm-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.LLM)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory"), + patch("services.workflow_service.DraftVariableSaver"), + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.LLM + mock_node.title = "LLM Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "llm" + mock_execution_record.node_id = "llm-node" + mock_execution_record.load_full_outputs.return_value = {"answer": "hello"} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={}, + query="", + files=None, + ) + + # Assert + # For non-start nodes, VariablePool should be initialized with environment_variables + mock_pool_cls.assert_called_once() + args, kwargs = mock_pool_cls.call_args + assert "environment_variables" in kwargs + + +# =========================================================================== +# TestWorkflowServiceHumanInputOperations +# Tests for Human Input related methods +# =========================================================================== + + +class TestWorkflowServiceHumanInputOperations: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_human_input_form_preview_should_raise_if_workflow_not_init(self, service: WorkflowService) -> None: + service.get_draft_workflow = MagicMock(return_value=None) + with pytest.raises(ValueError, match="Workflow not initialized"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_should_raise_if_wrong_node_type(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "llm"}} + service.get_draft_workflow = MagicMock(return_value=draft) + with patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM): + with pytest.raises(ValueError, match="Node type must be human-input"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = { + "id": "node-1", + "data": MagicMock(type=BuiltinNodeTypes.HUMAN_INPUT), + } + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.render_form_content_before_submission.return_value = "rendered" + mock_node.resolve_default_values.return_value = {"def": 1} + mock_node.title = "Form Title" + mock_node.node_data = MagicMock() + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.HumanInputRequired") as mock_required_cls, + ): + service.get_human_input_form_preview(app_model=app_model, account=account, node_id="node-1") + mock_node.render_form_content_before_submission.assert_called_once() + mock_required_cls.return_value.model_dump.assert_called_once() + + def test_submit_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.node_data = MagicMock() + mock_node.node_data.outputs_field_names.return_value = ["field1"] + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.validate_human_input_submission"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + ): + result = service.submit_human_input_form_preview( + app_model=app_model, account=account, node_id="node-1", form_inputs={"field1": "val1"}, action="submit" + ) + assert result["__action_id"] == "submit" + mock_saver_cls.return_value.save.assert_called_once() + + def test_test_human_input_delivery_success(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, + patch("services.workflow_service.apply_debug_email_recipient"), + patch.object(service, "_build_human_input_variable_pool"), + patch.object(service, "_build_human_input_node"), + patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), + patch("services.workflow_service.HumanInputDeliveryTestService") as mock_test_srv, + ): + mock_resolve.return_value = MagicMock() + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="method-1" + ) + mock_test_srv.return_value.send_test.assert_called_once() + + def test_test_human_input_delivery_failure_cases(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method", return_value=None), + ): + with pytest.raises(ValueError, match="Delivery method not found"): + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="none" + ) + + def test_load_email_recipients_parsing_failure(self, service: WorkflowService) -> None: + # Arrange + mock_recipient = MagicMock() + mock_recipient.recipient_payload = "invalid json" + mock_recipient.recipient_type = RecipientType.EMAIL_MEMBER + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.Session") as mock_session_cls, + patch("services.workflow_service.select"), + patch("services.workflow_service.json.loads", side_effect=ValueError("bad json")), + ): + mock_session = mock_session_cls.return_value.__enter__.return_value + # sqlalchemy assertions check for .bind + mock_session.bind = MagicMock() # removed spec=Engine to avoid import issues for now + mock_session.scalars.return_value.all.return_value = [mock_recipient] + + # Act + # _load_email_recipients(form_id: str) is a static method + result = WorkflowService._load_email_recipients("form-1") + + # Assert + assert result == [] # Should fall back to empty list on parsing error + + def test_build_human_input_variable_pool(self, service: WorkflowService) -> None: + workflow = MagicMock() + workflow.environment_variables = [] + workflow.graph_dict = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.HumanInputNode.extract_variable_selector_to_variable_mapping"), + patch("services.workflow_service.load_into_variable_pool"), + patch("services.workflow_service.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + service._build_human_input_variable_pool( + app_model=MagicMock(), workflow=workflow, node_config={}, manual_inputs={}, user_id="user-1" + ) + mock_pool_cls.assert_called_once() + + +# =========================================================================== +# TestWorkflowServiceFreeNodeExecution +# Tests for run_free_workflow_node and handle_single_step_result +# =========================================================================== + + +class TestWorkflowServiceFreeNodeExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_free_workflow_node_success(self, service: WorkflowService) -> None: + node_execution = MagicMock() + with ( + patch.object(service, "_handle_single_step_result", return_value=node_execution), + patch("services.workflow_service.WorkflowEntry.run_free_node"), + ): + result = service.run_free_workflow_node({}, "tenant-1", "user-1", "node-1", {}) + assert result == node_execution + + def test_validate_graph_structure_coexist_error(self, service: WorkflowService) -> None: + graph = { + "nodes": [ + {"data": {"type": "start"}}, + {"data": {"type": "trigger-webhook"}}, # is_trigger_node=True + ] + } + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + service.validate_graph_structure(graph) + + def test_validate_features_structure_success(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "workflow" + features = {} + with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_val: + service.validate_features_structure(app, features) + mock_val.assert_called_once() + + def test_validate_features_structure_invalid_mode(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "invalid" + with pytest.raises(ValueError, match="Invalid app mode"): + service.validate_features_structure(app, {}) + + def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: + with patch( + "dify_graph.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + ): + with pytest.raises(ValueError, match="Invalid HumanInput node data"): + service._validate_human_input_node_data({}) + + def test_rebuild_single_file_unreachable(self) -> None: + # Test line 1523 (unreachable) + with pytest.raises(Exception, match="unreachable"): + _rebuild_single_file("tenant-1", {}, cast(Any, "invalid_type")) + + def test_build_human_input_node(self, service: WorkflowService) -> None: + """Cover _build_human_input_node (lines 1065-1088).""" + workflow = MagicMock() + workflow.id = "wf-1" + workflow.tenant_id = "t-1" + workflow.app_id = "app-1" + account = MagicMock() + account.id = "u-1" + node_config = {"id": "n-1"} + variable_pool = MagicMock() + + with ( + patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.HumanInputNode") as mock_node_cls, + patch("services.workflow_service.HumanInputFormRepositoryImpl"), + ): + node = service._build_human_input_node( + workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool + ) + assert node == mock_node_cls.return_value + mock_node_cls.assert_called_once() diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py new file mode 100644 index 0000000000..9bfd7eb2c5 --- /dev/null +++ b/api/tests/unit_tests/services/test_workspace_service.py @@ -0,0 +1,576 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from models.account import Tenant + +# --------------------------------------------------------------------------- +# Constants used throughout the tests +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +ACCOUNT_ID = "account-xyz" +FILES_BASE_URL = "https://files.example.com" + +DB_PATH = "services.workspace_service.db" +FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" +TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" +DIFY_CONFIG_PATH = "services.workspace_service.dify_config" +CURRENT_USER_PATH = "services.workspace_service.current_user" +CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" + + +# --------------------------------------------------------------------------- +# Helpers / factories +# --------------------------------------------------------------------------- + + +def _make_tenant( + tenant_id: str = TENANT_ID, + name: str = "My Workspace", + plan: str = "sandbox", + status: str = "active", + custom_config: dict | None = None, +) -> Tenant: + """Create a minimal Tenant-like namespace.""" + return cast( + Tenant, + SimpleNamespace( + id=tenant_id, + name=name, + plan=plan, + status=status, + created_at="2024-01-01T00:00:00Z", + custom_config_dict=custom_config or {}, + ), + ) + + +def _make_feature( + can_replace_logo: bool = False, + next_credit_reset_date: str | None = None, + billing_plan: str = "sandbox", +) -> MagicMock: + """Create a feature namespace matching what FeatureService.get_features returns.""" + feature = MagicMock() + feature.can_replace_logo = can_replace_logo + feature.next_credit_reset_date = next_credit_reset_date + feature.billing.subscription.plan = billing_plan + return feature + + +def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: + pool = MagicMock() + pool.quota_limit = quota_limit + pool.quota_used = quota_used + return pool + + +def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: + return SimpleNamespace(role=role) + + +def _tenant_info(result: object) -> dict[str, Any] | None: + return cast(dict[str, Any] | None, result) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_current_user() -> SimpleNamespace: + """Return a lightweight current_user stand-in.""" + return SimpleNamespace(id=ACCOUNT_ID) + + +@pytest.fixture +def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """ + Patch the common external boundaries used by WorkspaceService.get_tenant_info. + + Returns a dict of named mocks so individual tests can customise them. + """ + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) + mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "SELF_HOSTED" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "has_roles": mock_has_roles, + "config": mock_config, + } + + +# --------------------------------------------------------------------------- +# 1. None Tenant Handling +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: + """get_tenant_info should short-circuit and return None for a falsy tenant.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = None + + # Act + result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) + + # Assert + assert result is None + + +def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: + """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" + from services.workspace_service import WorkspaceService + + # Arrange / Act / Assert + assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# 2. Basic Tenant Info — happy path +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_base_fields( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """get_tenant_info should always return the six base scalar fields.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["id"] == TENANT_ID + assert result["name"] == "My Workspace" + assert result["plan"] == "sandbox" + assert result["status"] == "active" + assert result["created_at"] == "2024-01-01T00:00:00Z" + assert result["trial_end_reason"] is None + + +def test_get_tenant_info_should_populate_role_from_tenant_account_join( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The 'role' field should be taken from TenantAccountJoin, not the default.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["role"] == "admin" + + +def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The service asserts that TenantAccountJoin exists. + Missing join should raise AssertionError. + """ + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = None + tenant = _make_tenant() + + # Act + Assert + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + +# --------------------------------------------------------------------------- +# 3. Logo Customisation +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant( + custom_config={ + "replace_webapp_logo": True, + "remove_webapp_brand": True, + } + ) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" in result + assert result["custom_config"]["remove_webapp_brand"] is True + expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" + assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url + + +def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + +def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config should be absent when can_replace_logo is False.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block is gated on OWNER or ADMIN role.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = False # regular member + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_use_files_url_for_logo_url( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The logo URL should use dify_config.FILES_URL as the base.""" + from services.workspace_service import WorkspaceService + + # Arrange + custom_base = "https://cdn.mycompany.io" + basic_mocks["config"].FILES_URL = custom_base + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + +# --------------------------------------------------------------------------- +# 4. Cloud-Edition Credit Features +# --------------------------------------------------------------------------- + +CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX + + +@pytest.fixture +def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """Patches for CLOUD edition tests, billing plan = professional by default.""" + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch( + FEATURE_SERVICE_PATH, + return_value=_make_feature( + can_replace_logo=False, + next_credit_reset_date="2025-02-01", + billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, + ), + ) + mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "CLOUD" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "config": mock_config, + } + + +def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """next_credit_reset_date should be present in CLOUD edition.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch( + CREDIT_POOL_SERVICE_PATH, + side_effect=[None, None], # both paid and trial pools absent + ) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + +def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """trial_credits/trial_credits_used come from the paid pool when conditions are met.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=1000, quota_used=200) + mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + +def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """quota_limit == -1 means unlimited; service should still use the paid pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=-1, quota_used=999) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid pool is exhausted (used >= limit), switch to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full + trial_pool = _make_pool(quota_limit=100, quota_used=10) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid_pool is None, fall back to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + trial_pool = _make_pool(quota_limit=50, quota_used=5) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """ + When the subscription plan IS SANDBOX, the paid pool branch is skipped + entirely and we fall back to the trial pool. + """ + from enums.cloud_plan import CloudPlan + from services.workspace_service import WorkspaceService + + # Arrange — override billing plan to SANDBOX + cloud_mocks["get_features"].return_value = _make_feature( + next_credit_reset_date="2025-02-01", + billing_plan=CloudPlan.SANDBOX, + ) + paid_pool = _make_pool(quota_limit=1000, quota_used=0) + trial_pool = _make_pool(quota_limit=200, quota_used=20) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + +def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When both paid and trial pools are absent, trial_credits should not be set.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 5. Self-hosted / Non-Cloud Edition +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + from services.workspace_service import WorkspaceService + + # Arrange (basic_mocks already sets EDITION = "SELF_HOSTED") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 6. DB query integrity +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The DB query for TenantAccountJoin must be scoped to the correct + tenant_id and current_user.id. + """ + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant(tenant_id="my-special-tenant") + mock_current_user = mocker.patch(CURRENT_USER_PATH) + mock_current_user.id = "special-user-id" + + # Act + WorkspaceService.get_tenant_info(tenant) + + # Assert — db.session.query was invoked (at least once) + basic_mocks["db_session"].query.assert_called() diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py deleted file mode 100644 index 9616d2f102..0000000000 --- a/api/tests/unit_tests/services/tools/test_tools_transform_service.py +++ /dev/null @@ -1,452 +0,0 @@ -from unittest.mock import Mock - -from core.tools.__base.tool import Tool -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType -from services.tools.tools_transform_service import ToolTransformService - - -class TestToolTransformService: - """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" - - def test_convert_tool_with_parameter_override(self): - """Test that runtime parameters correctly override base parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - # Create mock runtime parameters that override base parameters - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" # Different label to verify override - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.author == "test_author" - assert result.name == "test_tool" - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Find the overridden parameter - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" # Should be runtime version - - # Find the non-overridden parameter - original_param = next((p for p in result.parameters if p.name == "param2"), None) - assert original_param is not None - assert original_param.label == "Base Param 2" # Should be base version - - def test_convert_tool_with_additional_runtime_parameters(self): - """Test that additional runtime parameters are added to the final list""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters - one that overrides and one that's new - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "runtime_only" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Only Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Check that both parameters are present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "runtime_only" in param_names - - # Verify the overridden parameter has runtime version - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" - - # Verify the new runtime parameter is included - new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) - assert new_param is not None - assert new_param.label == "Runtime Only Param" - - def test_convert_tool_with_non_form_runtime_parameters(self): - """Test that non-FORM runtime parameters are not added as new parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters with different forms - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "llm_param" - runtime_param2.form = ToolParameter.ToolParameterForm.LLM - runtime_param2.type = "string" - runtime_param2.label = "LLM Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 1 # Only the FORM parameter should be present - - # Check that only the FORM parameter is present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "llm_param" not in param_names - - def test_convert_tool_with_empty_parameters(self): - """Test conversion with empty base and runtime parameters""" - # Create mock tool with no parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_with_none_parameters(self): - """Test conversion when base parameters is None""" - # Create mock tool with None parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = None - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_parameter_order_preserved(self): - """Test that parameter order is preserved correctly""" - # Create mock base parameters in specific order - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - base_param3 = Mock(spec=ToolParameter) - base_param3.name = "param3" - base_param3.form = ToolParameter.ToolParameterForm.FORM - base_param3.type = "string" - base_param3.label = "Base Param 3" - - # Create runtime parameter that overrides middle parameter - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "param2" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Param 2" - - # Create new runtime parameter - runtime_param4 = Mock(spec=ToolParameter) - runtime_param4.name = "param4" - runtime_param4.form = ToolParameter.ToolParameterForm.FORM - runtime_param4.type = "string" - runtime_param4.label = "Runtime Param 4" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2, base_param3] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 4 - - # Check that order is maintained: base parameters first, then new runtime parameters - param_names = [p.name for p in result.parameters] - assert param_names == ["param1", "param2", "param3", "param4"] - - # Verify that param2 was overridden with runtime version - param2 = result.parameters[1] - assert param2.name == "param2" - assert param2.label == "Runtime Param 2" - - -class TestWorkflowProviderToUserProvider: - """Test cases for ToolTransformService.workflow_provider_to_user_provider method""" - - def test_workflow_provider_to_user_provider_with_workflow_app_id(self): - """Test that workflow_provider_to_user_provider correctly sets workflow_app_id.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - workflow_app_id = "app_123" - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1", "label2"], - workflow_app_id=workflow_app_id, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "test_author" - assert result.name == "test_workflow_tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["label1", "label2"] - assert result.is_team_authorization is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] - - def test_workflow_provider_to_user_provider_without_workflow_app_id(self): - """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method without workflow_app_id - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1"], - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == ["label1"] - - def test_workflow_provider_to_user_provider_workflow_app_id_none(self): - """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method with explicit None values - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=None, - workflow_app_id=None, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == [] - - def test_workflow_provider_to_user_provider_preserves_other_fields(self): - """Test that workflow_provider_to_user_provider preserves all other entity fields.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller with various fields - workflow_app_id = "app_456" - provider_id = "provider_456" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "another_author" - mock_controller.entity.identity.name = "another_workflow_tool" - mock_controller.entity.identity.description = I18nObject( - en_US="Another description", zh_Hans="Another description" - ) - mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"} - mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.label = I18nObject( - en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool" - ) - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["automation", "workflow"], - workflow_app_id=workflow_app_id, - ) - - # Verify all fields are preserved correctly - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "another_author" - assert result.name == "another_workflow_tool" - assert result.description.en_US == "Another description" - assert result.description.zh_Hans == "Another description" - assert result.icon == {"type": "emoji", "content": "⚙️"} - assert result.icon_dark == {"type": "emoji", "content": "🔧"} - assert result.label.en_US == "Another Workflow Tool" - assert result.label.zh_Hans == "Another Workflow Tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["automation", "workflow"] - assert result.masked_credentials == {} - assert result.is_team_authorization is True - assert result.allow_delete is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index ae59da0a3d..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,162 +0,0 @@ -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service - - -class DummyWorkflow: - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - def __init__(self, result): - self._result = result - - def where(self, *args, **kwargs): - return self - - def first(self): - return self._result - - -class DummySession: - def __init__(self) -> None: - self.added: list[object] = [] - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - def add(self, obj) -> None: - self.added.append(obj) - - def begin(self): - return DummyBegin(self) - - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) - - -def _build_parameters() -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM), - ] - - -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "tool"}, - description="desc", - parameters=_build_parameters(), - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() - - -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - icon = {"type": "emoji", "emoji": "tool"} - - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) - - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b2..33a5607ef4 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,6 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -151,8 +152,8 @@ class VectorServiceTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", index_struct_dict: dict | None = None, @@ -493,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -505,7 +506,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model") + VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_called_once() @@ -534,7 +535,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -567,7 +568,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -590,7 +591,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -615,7 +616,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="economy" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -649,7 +650,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_not_called() @@ -668,7 +669,7 @@ class TestVectorService: store when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -694,7 +695,7 @@ class TestVectorService: index when using economy indexing with keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -730,7 +731,7 @@ class TestVectorService: index when using economy indexing without keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -894,7 +895,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -922,7 +923,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -950,7 +951,7 @@ class TestVectorService: when there are new chunks, updated chunks, and deleted chunks. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") @@ -992,7 +993,7 @@ class TestVectorService: add_texts is called, not delete_by_ids. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1018,7 +1019,7 @@ class TestVectorService: delete_by_ids is called, not add_texts. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1044,7 +1045,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1074,7 +1075,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1098,7 +1099,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py deleted file mode 100644 index dfe325648d..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ /dev/null @@ -1,127 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import Session - -from models.model import App -from models.workflow import Workflow -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService - - -@pytest.fixture -def workflow_setup(): - mock_session_maker = MagicMock() - workflow_service = WorkflowService(mock_session_maker) - session = MagicMock(spec=Session) - tenant_id = "test-tenant-id" - workflow_id = "test-workflow-id" - - # Mock workflow - workflow = MagicMock(spec=Workflow) - workflow.id = workflow_id - workflow.tenant_id = tenant_id - workflow.version = "1.0" # Not a draft - workflow.tool_published = False # Not published as a tool by default - - # Mock app - app = MagicMock(spec=App) - app.id = "test-app-id" - app.name = "Test App" - app.workflow_id = None # Not used by an app by default - - return { - "workflow_service": workflow_service, - "session": session, - "tenant_id": tenant_id, - "workflow_id": workflow_id, - "workflow": workflow, - "app": app, - } - - -def test_delete_workflow_success(workflow_setup): - # Setup mocks - - # Mock the tool provider query to return None (not published as a tool) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = None - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method - result = workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - assert result is True - workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"]) - - -def test_delete_workflow_draft_error(workflow_setup): - # Setup mocks - workflow_setup["workflow"].version = "draft" - workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"]) - - # Call the method and verify exception - with pytest.raises(DraftWorkflowDeletionError): - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_in_use_by_app_error(workflow_setup): - # Setup mocks - workflow_setup["app"].workflow_id = workflow_setup["workflow_id"] - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], workflow_setup["app"]] - ) # Return workflow first, then app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message contains app name - assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_published_as_tool_error(workflow_setup): - # Setup mocks - from models.tools import WorkflowToolProvider - - # Mock the tool provider query - mock_tool_provider = MagicMock(spec=WorkflowToolProvider) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message - assert "Cannot delete workflow that is published as a tool" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py deleted file mode 100644 index 79bf5e94c2..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from repositories.sqlalchemy_api_workflow_node_execution_repository import ( - DifyAPISQLAlchemyWorkflowNodeExecutionRepository, -) - - -class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: - @pytest.fixture - def repository(self): - mock_session_maker = MagicMock() - return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) - - def test_repository_implements_protocol(self, repository): - """Test that the repository implements the required protocol methods.""" - # Verify all protocol methods are implemented - assert hasattr(repository, "get_node_last_execution") - assert hasattr(repository, "get_executions_by_workflow_run") - assert hasattr(repository, "get_execution_by_id") - - # Verify methods are callable - assert callable(repository.get_node_last_execution) - assert callable(repository.get_executions_by_workflow_run) - assert callable(repository.get_execution_by_id) - assert callable(repository.delete_expired_executions) - assert callable(repository.delete_executions_by_app) - assert callable(repository.get_expired_executions_batch) - assert callable(repository.delete_executions_by_ids) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_restore.py b/api/tests/unit_tests/services/workflow/test_workflow_restore.py new file mode 100644 index 0000000000..179361de45 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_restore.py @@ -0,0 +1,77 @@ +import json +from types import SimpleNamespace + +from models.workflow import Workflow +from services.workflow_restore import apply_published_workflow_snapshot_to_draft + +LEGACY_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NORMALIZED_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def _create_workflow(*, workflow_id: str, version: str, features: dict[str, object]) -> Workflow: + return Workflow( + id=workflow_id, + tenant_id="tenant-id", + app_id="app-id", + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(features), + created_by="account-id", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_apply_published_workflow_snapshot_to_draft_copies_serialized_features_without_mutating_source() -> None: + source_workflow = _create_workflow( + workflow_id="published-workflow-id", + version="2026-03-19T00:00:00", + features=LEGACY_FEATURES, + ) + + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id="tenant-id", + app_id="app-id", + source_workflow=source_workflow, + draft_workflow=None, + account=SimpleNamespace(id="account-id"), + updated_at_factory=lambda: source_workflow.updated_at, + ) + + assert is_new_draft is True + assert source_workflow.serialized_features == json.dumps(LEGACY_FEATURES) + assert source_workflow.normalized_features_dict == NORMALIZED_FEATURES + assert draft_workflow.serialized_features == json.dumps(LEGACY_FEATURES) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 74ba7f9c34..936a10d6c5 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -183,10 +184,10 @@ class TestErrorHandling: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -228,10 +229,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=pipeline_id, ) @@ -264,10 +265,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=None, ) @@ -320,10 +321,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -365,10 +366,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert - storage delete was attempted @@ -407,10 +408,10 @@ class TestEdgeCases: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -444,7 +445,7 @@ class TestIndexProcessorParameters: - Dataset object with correct attributes is passed """ # Arrange - indexing_technique = "high_quality" + indexing_technique = IndexTechniqueType.HIGH_QUALITY index_struct = '{"type": "paragraph"}' # Act @@ -454,7 +455,7 @@ class TestIndexProcessorParameters: indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 8a721124d6..0b189ebae2 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -58,6 +59,11 @@ def mock_redis(): # Redis is already mocked globally in conftest.py # Reset it for each test redis_client.reset_mock() + redis_client.get.reset_mock() + redis_client.setex.reset_mock() + redis_client.delete.reset_mock() + redis_client.lpush.reset_mock() + redis_client.rpop.reset_mock() redis_client.get.return_value = None redis_client.setex.return_value = True redis_client.delete.return_value = True @@ -203,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id): dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -222,7 +228,7 @@ def mock_documents(document_ids, dataset_id): doc.stopped_at = None doc.processing_started_at = None # optional attribute used in some code paths - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX documents.append(doc) return documents diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 3668416e36..f49f4535af 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, document.tenant_id = str(uuid.uuid4()) document.data_source_type = "notion_import" document.indexing_status = "completed" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, diff --git a/api/uv.lock b/api/uv.lock index ddb70f6b54..47a3c45df0 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -169,12 +169,6 @@ version = "1.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a0/87/1d7019d23891897cb076b2f7e3c81ab3c2ba91de3bb067196f675d60d34c/alibabacloud-credentials-api-1.0.0.tar.gz", hash = "sha256:8c340038d904f0218d7214a8f4088c31912bfcf279af2cbc7d9be4897a97dd2f", size = 2330, upload-time = "2025-01-13T05:53:04.931Z" } -[[package]] -name = "alibabacloud-endpoint-util" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/92/7d/8cc92a95c920e344835b005af6ea45a0db98763ad6ad19299d26892e6c8d/alibabacloud_endpoint_util-0.0.4.tar.gz", hash = "sha256:a593eb8ddd8168d5dc2216cd33111b144f9189fcd6e9ca20e48f358a739bbf90", size = 2813, upload-time = "2025-06-12T07:20:52.572Z" } - [[package]] name = "alibabacloud-gateway-spi" version = "0.0.3" @@ -186,69 +180,17 @@ sdist = { url = "https://files.pythonhosted.org/packages/ab/98/d7111245f17935bf7 [[package]] name = "alibabacloud-gpdb20160503" -version = "3.8.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-openplatform20191219" }, - { name = "alibabacloud-oss-sdk" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/15/6a/cc72e744e95c8f37fa6a84e66ae0b9b57a13ee97a0ef03d94c7127c31d75/alibabacloud_gpdb20160503-3.8.3.tar.gz", hash = "sha256:4dfcc0d9cff5a921d529d76f4bf97e2ceb9dc2fa53f00ab055f08509423d8e30", size = 155092, upload-time = "2024-07-18T17:09:42.438Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/36/bce41704b3bf59d607590ec73a42a254c5dea27c0f707aee11d20512a200/alibabacloud_gpdb20160503-3.8.3-py3-none-any.whl", hash = "sha256:06e1c46ce5e4e9d1bcae76e76e51034196c625799d06b2efec8d46a7df323fe8", size = 156097, upload-time = "2024-07-18T17:09:40.414Z" }, -] - -[[package]] -name = "alibabacloud-openapi-util" -version = "0.2.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea-util" }, - { name = "cryptography" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f6/50/5f41ab550d7874c623f6e992758429802c4b52a6804db437017e5387de33/alibabacloud_openapi_util-0.2.2.tar.gz", hash = "sha256:ebbc3906f554cb4bf8f513e43e8a33e8b6a3d4a0ef13617a0e14c3dda8ef52a8", size = 7201, upload-time = "2023-10-23T07:44:18.523Z" } - -[[package]] -name = "alibabacloud-openplatform20191219" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4f/bf/f7fa2f3657ed352870f442434cb2f27b7f70dcd52a544a1f3998eeaf6d71/alibabacloud_openplatform20191219-2.0.0.tar.gz", hash = "sha256:e67f4c337b7542538746592c6a474bd4ae3a9edccdf62e11a32ca61fad3c9020", size = 5038, upload-time = "2022-09-21T06:16:10.683Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/e5/18c75213551eeca9db1f6b41ddcc0bd87b5b6508c75a67f05cd8671847b4/alibabacloud_openplatform20191219-2.0.0-py3-none-any.whl", hash = "sha256:873821c45bca72a6c6ec7a906c9cb21554c122e88893bbac3986934dab30dd36", size = 5204, upload-time = "2022-09-21T06:16:07.844Z" }, -] - -[[package]] -name = "alibabacloud-oss-sdk" -version = "0.1.1" +version = "5.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-util" }, - { name = "alibabacloud-tea-xml" }, + { name = "alibabacloud-tea-openapi" }, + { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/d1/f442dd026908fcf55340ca694bb1d027aa91e119e76ae2fbea62f2bde4f4/alibabacloud_oss_sdk-0.1.1.tar.gz", hash = "sha256:f51a368020d0964fcc0978f96736006f49f5ab6a4a4bf4f0b8549e2c659e7358", size = 46434, upload-time = "2025-04-22T12:40:41.717Z" } - -[[package]] -name = "alibabacloud-oss-util" -version = "0.0.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, +sdist = { url = "https://files.pythonhosted.org/packages/b3/36/69333c7fb7fb5267f338371b14fdd8dbdd503717c97bbc7a6419d155ab4c/alibabacloud_gpdb20160503-5.1.0.tar.gz", hash = "sha256:086ec6d5e39b64f54d0e44bb3fd4fde1a4822a53eb9f6ff7464dff7d19b07b63", size = 295641, upload-time = "2026-03-19T10:09:02.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/7f/a91a2f9ad97c92fa9a6981587ea0ff789240cea05b17b17b7c244e5bac64/alibabacloud_gpdb20160503-5.1.0-py3-none-any.whl", hash = "sha256:580e4579285a54c7f04570782e0f60423a1997568684187fe88e4110acfb640e", size = 848784, upload-time = "2026-03-19T10:09:00.72Z" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/7c/d7e812b9968247a302573daebcfef95d0f9a718f7b4bfcca8d3d83e266be/alibabacloud_oss_util-0.0.6.tar.gz", hash = "sha256:d3ecec36632434bd509a113e8cf327dc23e830ac8d9dd6949926f4e334c8b5d6", size = 10008, upload-time = "2021-04-28T09:25:04.056Z" } [[package]] name = "alibabacloud-tea" @@ -260,15 +202,6 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0f3f9d7f26b76b9ed93d4101add7867a2c87ed2534f5/alibabacloud-tea-0.4.3.tar.gz", hash = "sha256:ec8053d0aa8d43ebe1deb632d5c5404339b39ec9a18a0707d57765838418504a", size = 8785, upload-time = "2025-03-24T07:34:42.958Z" } -[[package]] -name = "alibabacloud-tea-fileform" -version = "0.0.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/22/8a/ef8ddf5ee0350984cad2749414b420369fe943e15e6d96b79be45367630e/alibabacloud_tea_fileform-0.0.5.tar.gz", hash = "sha256:fd00a8c9d85e785a7655059e9651f9e91784678881831f60589172387b968ee8", size = 3961, upload-time = "2021-04-28T09:22:54.56Z" } - [[package]] name = "alibabacloud-tea-openapi" version = "0.4.3" @@ -297,15 +230,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/9e/c394b4e2104766fb28a1e44e3ed36e4c7773b4d05c868e482be99d5635c9/alibabacloud_tea_util-0.3.14-py3-none-any.whl", hash = "sha256:10d3e5c340d8f7ec69dd27345eb2fc5a1dab07875742525edf07bbe86db93bfe", size = 6697, upload-time = "2025-11-19T06:01:07.355Z" }, ] -[[package]] -name = "alibabacloud-tea-xml" -version = "0.0.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f3feae9b5681588762929dc4da0176667297c2784c1a/alibabacloud_tea_xml-0.0.3.tar.gz", hash = "sha256:979cb51fadf43de77f41c69fc69c12529728919f849723eb0cd24eb7b048a90c", size = 3466, upload-time = "2025-07-01T08:04:55.144Z" } - [[package]] name = "aliyun-log-python-sdk" version = "0.9.37" @@ -570,28 +494,28 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.38.2" +version = "1.38.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/a3/20aa7c4e83f2f614e0036300f3c352775dede0655c66814da16c37b661a9/basedpyright-1.38.2.tar.gz", hash = "sha256:b433b2b8ba745ed7520cdc79a29a03682f3fb00346d272ece5944e9e5e5daa92", size = 25277019, upload-time = "2026-02-26T11:18:43.594Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/58/7abba2c743571a42b2548f07aee556ebc1e4d0bc2b277aeba1ee6c83b0af/basedpyright-1.38.3.tar.gz", hash = "sha256:9725419786afbfad8a9539527f162da02d462afad440b0412fdb3f3cdf179b90", size = 25277430, upload-time = "2026-03-17T13:10:41.526Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/12/736cab83626fea3fe65cdafb3ef3d2ee9480c56723f2fd33921537289a5e/basedpyright-1.38.2-py3-none-any.whl", hash = "sha256:153481d37fd19f9e3adedc8629d1d071b10c5f5e49321fb026b74444b7c70e24", size = 12312475, upload-time = "2026-02-26T11:18:40.373Z" }, + { url = "https://files.pythonhosted.org/packages/2c/e3/3ebb5c23bd3abb5fc2053b8a06a889aa5c1cf8cff738c78cb6c1957e90cd/basedpyright-1.38.3-py3-none-any.whl", hash = "sha256:1f15c2e489c67d6c5e896c24b6a63251195c04223a55e4568b8f8e8ed49ca830", size = 12313363, upload-time = "2026-03-17T13:10:47.344Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.63" +version = "0.9.64" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/ab/4c2927b01a97562af6a296b722eee79658335795f341a395a12742d5e1a3/bce_python_sdk-0.9.63.tar.gz", hash = "sha256:0c80bc3ac128a0a144bae3b8dff1f397f42c30b36f7677e3a39d8df8e77b1088", size = 284419, upload-time = "2026-03-06T14:54:06.592Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/047e9c1a6c97e0cd4d93a6490abd8fbc2ccd13569462fc0228699edc08bc/bce_python_sdk-0.9.64.tar.gz", hash = "sha256:901bf787c26ad35855a80d65e58d7584c8541f7f0f2af20847830e572e5b622e", size = 287125, upload-time = "2026-03-17T11:24:29.345Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/a4/501e978776c7060aa8ba77e68536597e754d938bcdbe1826618acebfbddf/bce_python_sdk-0.9.63-py3-none-any.whl", hash = "sha256:ec66eee8807c6aa4036412592da7e8c9e2cd7fdec494190986288ac2195d8276", size = 400305, upload-time = "2026-03-06T14:53:52.887Z" }, + { url = "https://files.pythonhosted.org/packages/48/7f/dd289582f37ab4effea47b2a8503880db4781ca0fc8e0a8ed5ff493359e5/bce_python_sdk-0.9.64-py3-none-any.whl", hash = "sha256:eaad97e4f0e7d613ae978da3cdc5294e9f724ffca2735f79820037fa1317cd6d", size = 402233, upload-time = "2026-03-17T11:24:24.673Z" }, ] [[package]] @@ -660,14 +584,14 @@ wheels = [ [[package]] name = "bleach" -version = "6.2.0" +version = "6.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/9a/0e33f5054c54d349ea62c277191c020c2d6ef1d65ab2cb1993f91ec846d1/bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f", size = 203083, upload-time = "2024-10-29T18:30:40.477Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e", size = 163406, upload-time = "2024-10-29T18:30:38.186Z" }, + { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, ] [[package]] @@ -706,30 +630,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/ae/60c642aa5413e560b671da825329f510b29a77274ed0f580bde77562294d/boto3-1.42.68.tar.gz", hash = "sha256:3f349f967ab38c23425626d130962bcb363e75f042734fe856ea8c5a00eef03c", size = 112761, upload-time = "2026-03-13T19:32:17.137Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/8b/d00575be514744ca4839e7d85bf4a8a3c7b6b4574433291e58d14c68ae09/boto3-1.42.73.tar.gz", hash = "sha256:d37b58d6cd452ca808dd6823ae19ca65b6244096c5125ef9052988b337298bae", size = 112775, upload-time = "2026-03-20T19:39:52.814Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/f6/dc6e993479dbb597d68223fbf61cb026511737696b15bd7d2a33e9b2c24f/boto3-1.42.68-py3-none-any.whl", hash = "sha256:dbff353eb7dc93cbddd7926ed24793e0174c04adbe88860dfa639568442e4962", size = 140556, upload-time = "2026-03-13T19:32:14.951Z" }, + { url = "https://files.pythonhosted.org/packages/aa/05/1fcf03d90abaa3d0b42a6bfd10231dd709493ecbacf794aa2eea5eae6841/boto3-1.42.73-py3-none-any.whl", hash = "sha256:1f81b79b873f130eeab14bb556417a7c66d38f3396b7f2fe3b958b3f9094f455", size = 140556, upload-time = "2026-03-20T19:39:50.298Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/8c/dd4b0c95ff008bed5a35ab411452ece121b355539d2a0b6dcd62a0c47be5/boto3_stubs-1.42.68.tar.gz", hash = "sha256:96ad1020735619483fb9b4da7a5e694b460bf2e18f84a34d5d175d0ffe8c4653", size = 101372, upload-time = "2026-03-13T19:49:54.867Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/c3/fcc47102c63278af25ad57c93d97dc393f4dbc54c0117a29c78f2b96ec1e/boto3_stubs-1.42.73.tar.gz", hash = "sha256:36f625769b5505c4bc627f16244b98de9e10dae3ac36f1aa0f0ebe2f201dc138", size = 101373, upload-time = "2026-03-20T19:59:51.463Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/15/3ca5848917214a168134512a5b45f856a56e913659888947a052e02031b5/boto3_stubs-1.42.68-py3-none-any.whl", hash = "sha256:ed7f98334ef7b2377fa8532190e63dc2c6d1dc895e3d7cb3d6d1c83771b81bf6", size = 70011, upload-time = "2026-03-13T19:49:42.801Z" }, + { url = "https://files.pythonhosted.org/packages/4b/57/d570ba61a2a0c7fe0c8667e41269a0480293cb53e1786d6661a2bd827fc5/boto3_stubs-1.42.73-py3-none-any.whl", hash = "sha256:bd658429069d8215247fc3abc003220cd875c24ab6eda7b3405090408afaacdf", size = 70009, upload-time = "2026-03-20T19:59:43.786Z" }, ] [package.optional-dependencies] @@ -739,16 +663,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3f/22/87502d5fbbfa8189406a617b30b1e2a3dc0ab2669f7268e91b385c1c1c7a/botocore-1.42.68.tar.gz", hash = "sha256:3951c69e12ac871dda245f48dac5c7dd88ea1bfdd74a8879ec356cf2874b806a", size = 14994514, upload-time = "2026-03-13T19:32:03.577Z" } +sdist = { url = "https://files.pythonhosted.org/packages/28/23/0c88ca116ef63b1ae77c901cd5d2095d22a8dbde9e80df74545db4a061b4/botocore-1.42.73.tar.gz", hash = "sha256:575858641e4949aaf2af1ced145b8524529edf006d075877af6b82ff96ad854c", size = 15008008, upload-time = "2026-03-20T19:39:40.082Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/2a/1428f6594799780fe6ee845d8e6aeffafe026cd16a70c878684e2dcbbfc8/botocore-1.42.68-py3-none-any.whl", hash = "sha256:9df7da26374601f890e2f115bfa573d65bf15b25fe136bb3aac809f6145f52ab", size = 14668816, upload-time = "2026-03-13T19:31:58.572Z" }, + { url = "https://files.pythonhosted.org/packages/8e/65/971f3d55015f4d133a6ff3ad74cd39f4b8dd8f53f7775a3c2ad378ea5145/botocore-1.42.73-py3-none-any.whl", hash = "sha256:7b62e2a12f7a1b08eb7360eecd23bb16fe3b7ab7f5617cf91b25476c6f86a0fe", size = 14681861, upload-time = "2026-03-20T19:39:35.341Z" }, ] [[package]] @@ -1290,41 +1214,41 @@ wheels = [ [[package]] name = "coverage" -version = "7.13.4" +version = "7.13.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/56/95b7e30fa389756cb56630faa728da46a27b8c6eb46f9d557c68fff12b65/coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91", size = 827239, upload-time = "2026-02-09T12:59:03.86Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/ad/b59e5b451cf7172b8d1043dc0fa718f23aab379bc1521ee13d4bd9bfa960/coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053", size = 219278, upload-time = "2026-02-09T12:56:31.673Z" }, - { url = "https://files.pythonhosted.org/packages/f1/17/0cb7ca3de72e5f4ef2ec2fa0089beafbcaaaead1844e8b8a63d35173d77d/coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11", size = 219783, upload-time = "2026-02-09T12:56:33.104Z" }, - { url = "https://files.pythonhosted.org/packages/ab/63/325d8e5b11e0eaf6d0f6a44fad444ae58820929a9b0de943fa377fe73e85/coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa", size = 250200, upload-time = "2026-02-09T12:56:34.474Z" }, - { url = "https://files.pythonhosted.org/packages/76/53/c16972708cbb79f2942922571a687c52bd109a7bd51175aeb7558dff2236/coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7", size = 252114, upload-time = "2026-02-09T12:56:35.749Z" }, - { url = "https://files.pythonhosted.org/packages/eb/c2/7ab36d8b8cc412bec9ea2d07c83c48930eb4ba649634ba00cb7e4e0f9017/coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00", size = 254220, upload-time = "2026-02-09T12:56:37.796Z" }, - { url = "https://files.pythonhosted.org/packages/d6/4d/cf52c9a3322c89a0e6febdfbc83bb45c0ed3c64ad14081b9503adee702e7/coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef", size = 256164, upload-time = "2026-02-09T12:56:39.016Z" }, - { url = "https://files.pythonhosted.org/packages/78/e9/eb1dd17bd6de8289df3580e967e78294f352a5df8a57ff4671ee5fc3dcd0/coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903", size = 250325, upload-time = "2026-02-09T12:56:40.668Z" }, - { url = "https://files.pythonhosted.org/packages/71/07/8c1542aa873728f72267c07278c5cc0ec91356daf974df21335ccdb46368/coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f", size = 251913, upload-time = "2026-02-09T12:56:41.97Z" }, - { url = "https://files.pythonhosted.org/packages/74/d7/c62e2c5e4483a748e27868e4c32ad3daa9bdddbba58e1bc7a15e252baa74/coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299", size = 249974, upload-time = "2026-02-09T12:56:43.323Z" }, - { url = "https://files.pythonhosted.org/packages/98/9f/4c5c015a6e98ced54efd0f5cf8d31b88e5504ecb6857585fc0161bb1e600/coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505", size = 253741, upload-time = "2026-02-09T12:56:45.155Z" }, - { url = "https://files.pythonhosted.org/packages/bd/59/0f4eef89b9f0fcd9633b5d350016f54126ab49426a70ff4c4e87446cabdc/coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6", size = 249695, upload-time = "2026-02-09T12:56:46.636Z" }, - { url = "https://files.pythonhosted.org/packages/b5/2c/b7476f938deb07166f3eb281a385c262675d688ff4659ad56c6c6b8e2e70/coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9", size = 250599, upload-time = "2026-02-09T12:56:48.13Z" }, - { url = "https://files.pythonhosted.org/packages/b8/34/c3420709d9846ee3785b9f2831b4d94f276f38884032dca1457fa83f7476/coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9", size = 221780, upload-time = "2026-02-09T12:56:50.479Z" }, - { url = "https://files.pythonhosted.org/packages/61/08/3d9c8613079d2b11c185b865de9a4c1a68850cfda2b357fae365cf609f29/coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f", size = 222715, upload-time = "2026-02-09T12:56:51.815Z" }, - { url = "https://files.pythonhosted.org/packages/18/1a/54c3c80b2f056164cc0a6cdcb040733760c7c4be9d780fe655f356f433e4/coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f", size = 221385, upload-time = "2026-02-09T12:56:53.194Z" }, - { url = "https://files.pythonhosted.org/packages/d1/81/4ce2fdd909c5a0ed1f6dedb88aa57ab79b6d1fbd9b588c1ac7ef45659566/coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459", size = 219449, upload-time = "2026-02-09T12:56:54.889Z" }, - { url = "https://files.pythonhosted.org/packages/5d/96/5238b1efc5922ddbdc9b0db9243152c09777804fb7c02ad1741eb18a11c0/coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3", size = 219810, upload-time = "2026-02-09T12:56:56.33Z" }, - { url = "https://files.pythonhosted.org/packages/78/72/2f372b726d433c9c35e56377cf1d513b4c16fe51841060d826b95caacec1/coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634", size = 251308, upload-time = "2026-02-09T12:56:57.858Z" }, - { url = "https://files.pythonhosted.org/packages/5d/a0/2ea570925524ef4e00bb6c82649f5682a77fac5ab910a65c9284de422600/coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3", size = 254052, upload-time = "2026-02-09T12:56:59.754Z" }, - { url = "https://files.pythonhosted.org/packages/e8/ac/45dc2e19a1939098d783c846e130b8f862fbb50d09e0af663988f2f21973/coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa", size = 255165, upload-time = "2026-02-09T12:57:01.287Z" }, - { url = "https://files.pythonhosted.org/packages/2d/4d/26d236ff35abc3b5e63540d3386e4c3b192168c1d96da5cb2f43c640970f/coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3", size = 257432, upload-time = "2026-02-09T12:57:02.637Z" }, - { url = "https://files.pythonhosted.org/packages/ec/55/14a966c757d1348b2e19caf699415a2a4c4f7feaa4bbc6326a51f5c7dd1b/coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a", size = 251716, upload-time = "2026-02-09T12:57:04.056Z" }, - { url = "https://files.pythonhosted.org/packages/77/33/50116647905837c66d28b2af1321b845d5f5d19be9655cb84d4a0ea806b4/coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7", size = 253089, upload-time = "2026-02-09T12:57:05.503Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b4/8efb11a46e3665d92635a56e4f2d4529de6d33f2cb38afd47d779d15fc99/coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc", size = 251232, upload-time = "2026-02-09T12:57:06.879Z" }, - { url = "https://files.pythonhosted.org/packages/51/24/8cd73dd399b812cc76bb0ac260e671c4163093441847ffe058ac9fda1e32/coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47", size = 255299, upload-time = "2026-02-09T12:57:08.245Z" }, - { url = "https://files.pythonhosted.org/packages/03/94/0a4b12f1d0e029ce1ccc1c800944a9984cbe7d678e470bb6d3c6bc38a0da/coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985", size = 250796, upload-time = "2026-02-09T12:57:10.142Z" }, - { url = "https://files.pythonhosted.org/packages/73/44/6002fbf88f6698ca034360ce474c406be6d5a985b3fdb3401128031eef6b/coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0", size = 252673, upload-time = "2026-02-09T12:57:12.197Z" }, - { url = "https://files.pythonhosted.org/packages/de/c6/a0279f7c00e786be75a749a5674e6fa267bcbd8209cd10c9a450c655dfa7/coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246", size = 221990, upload-time = "2026-02-09T12:57:14.085Z" }, - { url = "https://files.pythonhosted.org/packages/77/4e/c0a25a425fcf5557d9abd18419c95b63922e897bc86c1f327f155ef234a9/coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126", size = 222800, upload-time = "2026-02-09T12:57:15.944Z" }, - { url = "https://files.pythonhosted.org/packages/47/ac/92da44ad9a6f4e3a7debd178949d6f3769bedca33830ce9b1dcdab589a37/coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d", size = 221415, upload-time = "2026-02-09T12:57:17.497Z" }, - { url = "https://files.pythonhosted.org/packages/0d/4a/331fe2caf6799d591109bb9c08083080f6de90a823695d412a935622abb2/coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0", size = 211242, upload-time = "2026-02-09T12:59:02.032Z" }, + { url = "https://files.pythonhosted.org/packages/4b/37/d24c8f8220ff07b839b2c043ea4903a33b0f455abe673ae3c03bbdb7f212/coverage-7.13.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66a80c616f80181f4d643b0f9e709d97bcea413ecd9631e1dedc7401c8e6695d", size = 219381, upload-time = "2026-03-17T10:30:14.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/8b/cd129b0ca4afe886a6ce9d183c44d8301acbd4ef248622e7c49a23145605/coverage-7.13.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:145ede53ccbafb297c1c9287f788d1bc3efd6c900da23bf6931b09eafc931587", size = 219880, upload-time = "2026-03-17T10:30:16.231Z" }, + { url = "https://files.pythonhosted.org/packages/55/2f/e0e5b237bffdb5d6c530ce87cc1d413a5b7d7dfd60fb067ad6d254c35c76/coverage-7.13.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0672854dc733c342fa3e957e0605256d2bf5934feeac328da9e0b5449634a642", size = 250303, upload-time = "2026-03-17T10:30:17.748Z" }, + { url = "https://files.pythonhosted.org/packages/92/be/b1afb692be85b947f3401375851484496134c5554e67e822c35f28bf2fbc/coverage-7.13.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ec10e2a42b41c923c2209b846126c6582db5e43a33157e9870ba9fb70dc7854b", size = 252218, upload-time = "2026-03-17T10:30:19.804Z" }, + { url = "https://files.pythonhosted.org/packages/da/69/2f47bb6fa1b8d1e3e5d0c4be8ccb4313c63d742476a619418f85740d597b/coverage-7.13.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be3d4bbad9d4b037791794ddeedd7d64a56f5933a2c1373e18e9e568b9141686", size = 254326, upload-time = "2026-03-17T10:30:21.321Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d0/79db81da58965bd29dabc8f4ad2a2af70611a57cba9d1ec006f072f30a54/coverage-7.13.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4d2afbc5cc54d286bfb54541aa50b64cdb07a718227168c87b9e2fb8f25e1743", size = 256267, upload-time = "2026-03-17T10:30:23.094Z" }, + { url = "https://files.pythonhosted.org/packages/e5/32/d0d7cc8168f91ddab44c0ce4806b969df5f5fdfdbb568eaca2dbc2a04936/coverage-7.13.5-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3ad050321264c49c2fa67bb599100456fc51d004b82534f379d16445da40fb75", size = 250430, upload-time = "2026-03-17T10:30:25.311Z" }, + { url = "https://files.pythonhosted.org/packages/4d/06/a055311d891ddbe231cd69fdd20ea4be6e3603ffebddf8704b8ca8e10a3c/coverage-7.13.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7300c8a6d13335b29bb76d7651c66af6bd8658517c43499f110ddc6717bfc209", size = 252017, upload-time = "2026-03-17T10:30:27.284Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f6/d0fd2d21e29a657b5f77a2fe7082e1568158340dceb941954f776dce1b7b/coverage-7.13.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:eb07647a5738b89baab047f14edd18ded523de60f3b30e75c2acc826f79c839a", size = 250080, upload-time = "2026-03-17T10:30:29.481Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ab/0d7fb2efc2e9a5eb7ddcc6e722f834a69b454b7e6e5888c3a8567ecffb31/coverage-7.13.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9adb6688e3b53adffefd4a52d72cbd8b02602bfb8f74dcd862337182fd4d1a4e", size = 253843, upload-time = "2026-03-17T10:30:31.301Z" }, + { url = "https://files.pythonhosted.org/packages/ba/6f/7467b917bbf5408610178f62a49c0ed4377bb16c1657f689cc61470da8ce/coverage-7.13.5-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7c8d4bc913dd70b93488d6c496c77f3aff5ea99a07e36a18f865bca55adef8bd", size = 249802, upload-time = "2026-03-17T10:30:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/75/2c/1172fb689df92135f5bfbbd69fc83017a76d24ea2e2f3a1154007e2fb9f8/coverage-7.13.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0e3c426ffc4cd952f54ee9ffbdd10345709ecc78a3ecfd796a57236bfad0b9b8", size = 250707, upload-time = "2026-03-17T10:30:35.2Z" }, + { url = "https://files.pythonhosted.org/packages/67/21/9ac389377380a07884e3b48ba7a620fcd9dbfaf1d40565facdc6b36ec9ef/coverage-7.13.5-cp311-cp311-win32.whl", hash = "sha256:259b69bb83ad9894c4b25be2528139eecba9a82646ebdda2d9db1ba28424a6bf", size = 221880, upload-time = "2026-03-17T10:30:36.775Z" }, + { url = "https://files.pythonhosted.org/packages/af/7f/4cd8a92531253f9d7c1bbecd9fa1b472907fb54446ca768c59b531248dc5/coverage-7.13.5-cp311-cp311-win_amd64.whl", hash = "sha256:258354455f4e86e3e9d0d17571d522e13b4e1e19bf0f8596bcf9476d61e7d8a9", size = 222816, upload-time = "2026-03-17T10:30:38.891Z" }, + { url = "https://files.pythonhosted.org/packages/12/a6/1d3f6155fb0010ca68eba7fe48ca6c9da7385058b77a95848710ecf189b1/coverage-7.13.5-cp311-cp311-win_arm64.whl", hash = "sha256:bff95879c33ec8da99fc9b6fe345ddb5be6414b41d6d1ad1c8f188d26f36e028", size = 221483, upload-time = "2026-03-17T10:30:40.463Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" }, + { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" }, + { url = "https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" }, + { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" }, + { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" }, + { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" }, + { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" }, + { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" }, + { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] [package.optional-dependencies] @@ -1533,7 +1457,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.1" +version = "1.13.2" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1605,6 +1529,7 @@ dependencies = [ { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pypandoc" }, { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -1743,8 +1668,8 @@ requires-dist = [ { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, - { name = "bleach", specifier = "~=6.2.0" }, - { name = "boto3", specifier = "==1.42.68" }, + { name = "bleach", specifier = "~=6.3.0" }, + { name = "boto3", specifier = "==1.42.73" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.6.2" }, @@ -1762,7 +1687,7 @@ requires-dist = [ { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.3.0" }, { name = "google-api-core", specifier = ">=2.19.1" }, - { name = "google-api-python-client", specifier = "==2.192.0" }, + { name = "google-api-python-client", specifier = "==2.193.0" }, { name = "google-auth", specifier = ">=2.47.0" }, { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, @@ -1775,7 +1700,7 @@ requires-dist = [ { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, - { name = "litellm", specifier = "==1.82.2" }, + { name = "litellm", specifier = "==1.82.6" }, { name = "markdown", specifier = "~=3.10.2" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, @@ -1807,18 +1732,19 @@ requires-dist = [ { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, + { name = "pypandoc", specifier = "~=1.13" }, { name = "pypdfium2", specifier = "==5.6.0" }, { name = "python-docx", specifier = "~=1.2.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, { name = "redis", extras = ["hiredis"], specifier = "~=7.3.0" }, - { name = "resend", specifier = "~=2.23.0" }, + { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, - { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.54.0" }, + { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, { name = "sqlalchemy", specifier = "~=2.0.29" }, { name = "sseclient-py", specifier = "~=1.9.0" }, - { name = "starlette", specifier = "==0.52.1" }, + { name = "starlette", specifier = "==1.0.0" }, { name = "tiktoken", specifier = "~=0.12.0" }, { name = "transformers", specifier = "~=5.3.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, @@ -1841,10 +1767,10 @@ dev = [ { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=3.0.0" }, - { name = "pyrefly", specifier = ">=0.55.0" }, + { name = "pyrefly", specifier = ">=0.57.1" }, { name = "pytest", specifier = "~=9.0.2" }, { name = "pytest-benchmark", specifier = "~=5.2.3" }, - { name = "pytest-cov", specifier = "~=7.0.0" }, + { name = "pytest-cov", specifier = "~=7.1.0" }, { name = "pytest-env", specifier = "~=1.6.0" }, { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, @@ -1910,7 +1836,7 @@ tools = [ { name = "nltk", specifier = "~=3.9.1" }, ] vdb = [ - { name = "alibabacloud-gpdb20160503", specifier = "~=3.8.0" }, + { name = "alibabacloud-gpdb20160503", specifier = "~=5.1.0" }, { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, { name = "clickhouse-connect", specifier = "~=0.14.1" }, @@ -2499,7 +2425,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.192.0" +version = "2.193.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2508,9 +2434,9 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/d8/489052a40935e45b9b5b3d6accc14b041360c1507bdc659c2e1a19aaa3ff/google_api_python_client-2.192.0.tar.gz", hash = "sha256:d48cfa6078fadea788425481b007af33fe0ab6537b78f37da914fb6fc112eb27", size = 14209505, upload-time = "2026-03-05T15:17:01.598Z" } +sdist = { url = "https://files.pythonhosted.org/packages/90/f4/e14b6815d3b1885328dd209676a3a4c704882743ac94e18ef0093894f5c8/google_api_python_client-2.193.0.tar.gz", hash = "sha256:8f88d16e89d11341e0a8b199cafde0fb7e6b44260dffb88d451577cbd1bb5d33", size = 14281006, upload-time = "2026-03-17T18:25:29.415Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/76/ec4128f00fefb9011635ae2abc67d7dacd05c8559378f8f05f0c907c38d8/google_api_python_client-2.192.0-py3-none-any.whl", hash = "sha256:63a57d4457cd97df1d63eb89c5fda03c5a50588dcbc32c0115dd1433c08f4b62", size = 14783267, upload-time = "2026-03-05T15:16:58.804Z" }, + { url = "https://files.pythonhosted.org/packages/f0/6d/fe75167797790a56d17799b75e1129bb93f7ff061efc7b36e9731bd4be2b/google_api_python_client-2.193.0-py3-none-any.whl", hash = "sha256:c42aa324b822109901cfecab5dc4fc3915d35a7b376835233c916c70610322db", size = 14856490, upload-time = "2026-03-17T18:25:26.608Z" }, ] [[package]] @@ -2546,7 +2472,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.141.0" +version = "1.142.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2562,9 +2488,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, + { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, ] [[package]] @@ -2617,7 +2543,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.9.0" +version = "3.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2627,9 +2553,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/b1/4f0798e88285b50dfc60ed3a7de071def538b358db2da468c2e0deecbb40/google_cloud_storage-3.9.0.tar.gz", hash = "sha256:f2d8ca7db2f652be757e92573b2196e10fbc09649b5c016f8b422ad593c641cc", size = 17298544, upload-time = "2026-02-02T13:36:34.119Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/0b/816a6ae3c9fd096937d2e5f9670558908811d57d59ddf69dd4b83b326fd1/google_cloud_storage-3.9.0-py3-none-any.whl", hash = "sha256:2dce75a9e8b3387078cbbdad44757d410ecdb916101f8ba308abf202b6968066", size = 321324, upload-time = "2026-02-02T13:36:32.271Z" }, + { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, ] [[package]] @@ -3458,7 +3384,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.17" +version = "0.7.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -3471,9 +3397,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/79/81041dde07a974e728db7def23c1c7255950b8874102925cc77093bc847d/langsmith-0.7.17.tar.gz", hash = "sha256:6c1b0c2863cdd6636d2a58b8d5b1b80060703d98cac2593f4233e09ac25b5a9d", size = 1132228, upload-time = "2026-03-12T20:41:10.808Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/2a/2d5e6c67396fd228670af278c4da7bd6db2b8d11deaf6f108490b6d3f561/langsmith-0.7.22.tar.gz", hash = "sha256:35bfe795d648b069958280760564632fd28ebc9921c04f3e209c0db6a6c7dc04", size = 1134923, upload-time = "2026-03-19T22:45:23.492Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/31/62689d57f4d25792bd6a3c05c868771899481be2f3e31f9e71d31e1ac4ab/langsmith-0.7.17-py3-none-any.whl", hash = "sha256:cbec10460cb6c6ecc94c18c807be88a9984838144ae6c4693c9f859f378d7d02", size = 359147, upload-time = "2026-03-12T20:41:08.758Z" }, + { url = "https://files.pythonhosted.org/packages/1a/94/1f5d72655ab6534129540843776c40eff757387b88e798d8b3bf7e313fd4/langsmith-0.7.22-py3-none-any.whl", hash = "sha256:6e9d5148314d74e86748cb9d3898632cad0320c9323d95f70f969e5bc078eee4", size = 359927, upload-time = "2026-03-19T22:45:21.603Z" }, ] [[package]] @@ -3521,7 +3447,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.82.2" +version = "1.82.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3537,9 +3463,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/12/010a86643f12ac0b004032d5927c260094299a84ed38b5ed20a8f8c7e3c4/litellm-1.82.2.tar.gz", hash = "sha256:f5f4c4049f344a88bf80b2e421bb927807687c99624515d7ff4152d533ec9dcb", size = 17353218, upload-time = "2026-03-13T21:24:24.5Z" } +sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/e4/87e3ca82a8bf6e6bfffb42a539a1350dd6ced1b7169397bd439ba56fde10/litellm-1.82.2-py3-none-any.whl", hash = "sha256:641ed024774fa3d5b4dd9347f0efb1e31fa422fba2a6500aabedee085d1194cb", size = 15524224, upload-time = "2026-03-13T21:24:21.288Z" }, + { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, ] [[package]] @@ -4536,7 +4462,7 @@ wheels = [ [[package]] name = "opik" -version = "1.10.39" +version = "1.10.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4555,9 +4481,9 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/0f/b1e00a18cac16b4f36bf6cecc2de962fda810a9416d1159c48f46b81f5ec/opik-1.10.39.tar.gz", hash = "sha256:4d808eb2137070fc5d92a3bed3c3100d9cccfb35f4f0b71ea9990733f293dbb2", size = 780312, upload-time = "2026-03-12T14:08:25.746Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/17/edea6308347cec62e6828de7c573c596559c502b54fa4f0c88a52e2e81f5/opik-1.10.45.tar.gz", hash = "sha256:d8d8627ba03d12def46965e03d58f611daaf5cf878b3d087c53fe1159788c140", size = 789876, upload-time = "2026-03-20T11:35:12.457Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/24/0f4404907a98b4aec4508504570a78a61a3a8b5e451c67326632695ba8e6/opik-1.10.39-py3-none-any.whl", hash = "sha256:a72d735b9afac62e5262294b2f704aca89ec31f5c9beda17504815f7423870c3", size = 1317833, upload-time = "2026-03-12T14:08:23.954Z" }, + { url = "https://files.pythonhosted.org/packages/b7/17/150e9eecfa28cb23f7a0bfe83ae1486a11022b97fe6d12328b455784658d/opik-1.10.45-py3-none-any.whl", hash = "sha256:e8050d9e5e0d92ff587f156eacbdd02099897f39cfe79a98380b6c8ae9906b95", size = 1337714, upload-time = "2026-03-20T11:35:10.237Z" }, ] [[package]] @@ -5273,15 +5199,15 @@ wheels = [ [[package]] name = "pydantic-extra-types" -version = "2.11.0" +version = "2.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/35/2fee58b1316a73e025728583d3b1447218a97e621933fc776fb8c0f2ebdd/pydantic_extra_types-2.11.0.tar.gz", hash = "sha256:4e9991959d045b75feb775683437a97991d02c138e00b59176571db9ce634f0e", size = 157226, upload-time = "2025-12-31T16:18:27.944Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/17/fabd56da47096d240dd45ba627bead0333b0cf0ee8ada9bec579287dadf3/pydantic_extra_types-2.11.0-py3-none-any.whl", hash = "sha256:84b864d250a0fc62535b7ec591e36f2c5b4d1325fa0017eb8cda9aeb63b374a6", size = 74296, upload-time = "2025-12-31T16:18:26.38Z" }, + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] [[package]] @@ -5380,6 +5306,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/7d/037401cecb34728d1c28ea05e196ea3c9d50a1ce0f2172e586e075ff55d8/pyobvector-0.2.25-py3-none-any.whl", hash = "sha256:ae0153f99bd0222783ed7e3951efc31a0d2b462d926b6f86ebd2033409aede8f", size = 64663, upload-time = "2026-03-10T07:18:29.789Z" }, ] +[[package]] +name = "pypandoc" +version = "1.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/d6/410615fc433e5d1eacc00db2044ae2a9c82302df0d35366fe2bd15de024d/pypandoc-1.17.tar.gz", hash = "sha256:51179abfd6e582a25ed03477541b48836b5bba5a4c3b282a547630793934d799", size = 69071, upload-time = "2026-03-14T22:39:07.21Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/86/e2ffa604eacfbec3f430b1d850e7e04c4101eca1a5828f9ae54bf51dfba4/pypandoc-1.17-py3-none-any.whl", hash = "sha256:01fdbffa61edb9f8e82e8faad6954efcb7b6f8f0634aead4d89e322a00225a67", size = 23554, upload-time = "2026-03-14T22:38:46.007Z" }, +] + [[package]] name = "pypandoc-binary" version = "1.17" @@ -5405,11 +5340,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.8.0" +version = "6.9.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" }, + { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, ] [[package]] @@ -5467,18 +5402,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.55.0" +version = "0.57.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, - { url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, - { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, - { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, - { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, - { url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, + { url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" }, + { url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" }, + { url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" }, + { url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" }, + { url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" }, ] [[package]] @@ -5512,16 +5447,16 @@ wheels = [ [[package]] name = "pytest-cov" -version = "7.0.0" +version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, { name = "pluggy" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, ] [[package]] @@ -5957,15 +5892,15 @@ wheels = [ [[package]] name = "resend" -version = "2.23.0" +version = "2.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/a3/20003e7d14604fef778bd30c69604df3560a657a95a5c29a9688610759b6/resend-2.23.0.tar.gz", hash = "sha256:df613827dcc40eb1c9de2e5ff600cd4081b89b206537dec8067af1a5016d23c7", size = 31416, upload-time = "2026-02-23T19:01:57.603Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/ff/6a4e5e758fc2145c6a7d8563934d8ee24bf96a0212d7ec7d1af1f155bb74/resend-2.26.0.tar.gz", hash = "sha256:957a6a59dc597ce27fbd6d5383220dd9cc497fab99d4f3d775c8a42a449a569e", size = 36238, upload-time = "2026-03-20T22:49:09.728Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/35/64df775b8cd95e89798fd7b1b7fcafa975b6b09f559c10c0650e65b33580/resend-2.23.0-py2.py3-none-any.whl", hash = "sha256:eca6d28a1ffd36c1fc489fa83cb6b511f384792c9f07465f7c92d96c8b4d5636", size = 52599, upload-time = "2026-02-23T19:01:55.962Z" }, + { url = "https://files.pythonhosted.org/packages/16/c2/f88d3299d97aa1d36a923d0846fe185fcf5355ca898c954b2e5a79f090b5/resend-2.26.0-py2.py3-none-any.whl", hash = "sha256:5e25a804a84a68df504f2ade5369ac37e0139e37788a1f20b66c88696595b4bc", size = 57699, upload-time = "2026-03-20T22:49:08.354Z" }, ] [[package]] @@ -6046,27 +5981,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.6" +version = "0.15.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/df/f8629c19c5318601d3121e230f74cbee7a3732339c52b21daa2b82ef9c7d/ruff-0.15.6.tar.gz", hash = "sha256:8394c7bb153a4e3811a4ecdacd4a8e6a4fa8097028119160dffecdcdf9b56ae4", size = 4597916, upload-time = "2026-03-12T23:05:47.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/22/9e4f66ee588588dc6c9af6a994e12d26e19efbe874d1a909d09a6dac7a59/ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac", size = 4601277, upload-time = "2026-03-19T16:26:22.605Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/2f/4e03a7e5ce99b517e98d3b4951f411de2b0fa8348d39cf446671adcce9a2/ruff-0.15.6-py3-none-linux_armv6l.whl", hash = "sha256:7c98c3b16407b2cf3d0f2b80c80187384bc92c6774d85fefa913ecd941256fff", size = 10508953, upload-time = "2026-03-12T23:05:17.246Z" }, - { url = "https://files.pythonhosted.org/packages/70/60/55bcdc3e9f80bcf39edf0cd272da6fa511a3d94d5a0dd9e0adf76ceebdb4/ruff-0.15.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ee7dcfaad8b282a284df4aa6ddc2741b3f4a18b0555d626805555a820ea181c3", size = 10942257, upload-time = "2026-03-12T23:05:23.076Z" }, - { url = "https://files.pythonhosted.org/packages/e7/f9/005c29bd1726c0f492bfa215e95154cf480574140cb5f867c797c18c790b/ruff-0.15.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3bd9967851a25f038fc8b9ae88a7fbd1b609f30349231dffaa37b6804923c4bb", size = 10322683, upload-time = "2026-03-12T23:05:33.738Z" }, - { url = "https://files.pythonhosted.org/packages/5f/74/2f861f5fd7cbb2146bddb5501450300ce41562da36d21868c69b7a828169/ruff-0.15.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13f4594b04e42cd24a41da653886b04d2ff87adbf57497ed4f728b0e8a4866f8", size = 10660986, upload-time = "2026-03-12T23:05:53.245Z" }, - { url = "https://files.pythonhosted.org/packages/c1/a1/309f2364a424eccb763cdafc49df843c282609f47fe53aa83f38272389e0/ruff-0.15.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2ed8aea2f3fe57886d3f00ea5b8aae5bf68d5e195f487f037a955ff9fbaac9e", size = 10332177, upload-time = "2026-03-12T23:05:56.145Z" }, - { url = "https://files.pythonhosted.org/packages/30/41/7ebf1d32658b4bab20f8ac80972fb19cd4e2c6b78552be263a680edc55ac/ruff-0.15.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70789d3e7830b848b548aae96766431c0dc01a6c78c13381f423bf7076c66d15", size = 11170783, upload-time = "2026-03-12T23:06:01.742Z" }, - { url = "https://files.pythonhosted.org/packages/76/be/6d488f6adca047df82cd62c304638bcb00821c36bd4881cfca221561fdfc/ruff-0.15.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:542aaf1de3154cea088ced5a819ce872611256ffe2498e750bbae5247a8114e9", size = 12044201, upload-time = "2026-03-12T23:05:28.697Z" }, - { url = "https://files.pythonhosted.org/packages/71/68/e6f125df4af7e6d0b498f8d373274794bc5156b324e8ab4bf5c1b4fc0ec7/ruff-0.15.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c22e6f02c16cfac3888aa636e9eba857254d15bbacc9906c9689fdecb1953ab", size = 11421561, upload-time = "2026-03-12T23:05:31.236Z" }, - { url = "https://files.pythonhosted.org/packages/f1/9f/f85ef5fd01a52e0b472b26dc1b4bd228b8f6f0435975442ffa4741278703/ruff-0.15.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98893c4c0aadc8e448cfa315bd0cc343a5323d740fe5f28ef8a3f9e21b381f7e", size = 11310928, upload-time = "2026-03-12T23:05:45.288Z" }, - { url = "https://files.pythonhosted.org/packages/8c/26/b75f8c421f5654304b89471ed384ae8c7f42b4dff58fa6ce1626d7f2b59a/ruff-0.15.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:70d263770d234912374493e8cc1e7385c5d49376e41dfa51c5c3453169dc581c", size = 11235186, upload-time = "2026-03-12T23:05:50.677Z" }, - { url = "https://files.pythonhosted.org/packages/fc/d4/d5a6d065962ff7a68a86c9b4f5500f7d101a0792078de636526c0edd40da/ruff-0.15.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:55a1ad63c5a6e54b1f21b7514dfadc0c7fb40093fa22e95143cf3f64ebdcd512", size = 10635231, upload-time = "2026-03-12T23:05:37.044Z" }, - { url = "https://files.pythonhosted.org/packages/d6/56/7c3acf3d50910375349016cf33de24be021532042afbed87942858992491/ruff-0.15.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8dc473ba093c5ec238bb1e7429ee676dca24643c471e11fbaa8a857925b061c0", size = 10340357, upload-time = "2026-03-12T23:06:04.748Z" }, - { url = "https://files.pythonhosted.org/packages/06/54/6faa39e9c1033ff6a3b6e76b5df536931cd30caf64988e112bbf91ef5ce5/ruff-0.15.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:85b042377c2a5561131767974617006f99f7e13c63c111b998f29fc1e58a4cfb", size = 10860583, upload-time = "2026-03-12T23:05:58.978Z" }, - { url = "https://files.pythonhosted.org/packages/cb/1e/509a201b843b4dfb0b32acdedf68d951d3377988cae43949ba4c4133a96a/ruff-0.15.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cef49e30bc5a86a6a92098a7fbf6e467a234d90b63305d6f3ec01225a9d092e0", size = 11410976, upload-time = "2026-03-12T23:05:39.955Z" }, - { url = "https://files.pythonhosted.org/packages/6c/25/3fc9114abf979a41673ce877c08016f8e660ad6cf508c3957f537d2e9fa9/ruff-0.15.6-py3-none-win32.whl", hash = "sha256:bbf67d39832404812a2d23020dda68fee7f18ce15654e96fb1d3ad21a5fe436c", size = 10616872, upload-time = "2026-03-12T23:05:42.451Z" }, - { url = "https://files.pythonhosted.org/packages/89/7a/09ece68445ceac348df06e08bf75db72d0e8427765b96c9c0ffabc1be1d9/ruff-0.15.6-py3-none-win_amd64.whl", hash = "sha256:aee25bc84c2f1007ecb5037dff75cef00414fdf17c23f07dc13e577883dca406", size = 11787271, upload-time = "2026-03-12T23:05:20.168Z" }, - { url = "https://files.pythonhosted.org/packages/7f/d0/578c47dd68152ddddddf31cd7fc67dc30b7cdf639a86275fda821b0d9d98/ruff-0.15.6-py3-none-win_arm64.whl", hash = "sha256:c34de3dd0b0ba203be50ae70f5910b17188556630e2178fd7d79fc030eb0d837", size = 11060497, upload-time = "2026-03-12T23:05:25.968Z" }, + { url = "https://files.pythonhosted.org/packages/41/2f/0b08ced94412af091807b6119ca03755d651d3d93a242682bf020189db94/ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e", size = 10489037, upload-time = "2026-03-19T16:26:32.47Z" }, + { url = "https://files.pythonhosted.org/packages/91/4a/82e0fa632e5c8b1eba5ee86ecd929e8ff327bbdbfb3c6ac5d81631bef605/ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477", size = 10955433, upload-time = "2026-03-19T16:27:00.205Z" }, + { url = "https://files.pythonhosted.org/packages/ab/10/12586735d0ff42526ad78c049bf51d7428618c8b5c467e72508c694119df/ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e", size = 10269302, upload-time = "2026-03-19T16:26:26.183Z" }, + { url = "https://files.pythonhosted.org/packages/eb/5d/32b5c44ccf149a26623671df49cbfbd0a0ae511ff3df9d9d2426966a8d57/ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf", size = 10607625, upload-time = "2026-03-19T16:27:03.263Z" }, + { url = "https://files.pythonhosted.org/packages/5d/f1/f0001cabe86173aaacb6eb9bb734aa0605f9a6aa6fa7d43cb49cbc4af9c9/ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85", size = 10324743, upload-time = "2026-03-19T16:27:09.791Z" }, + { url = "https://files.pythonhosted.org/packages/7a/87/b8a8f3d56b8d848008559e7c9d8bf367934d5367f6d932ba779456e2f73b/ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0", size = 11138536, upload-time = "2026-03-19T16:27:06.101Z" }, + { url = "https://files.pythonhosted.org/packages/e4/f2/4fd0d05aab0c5934b2e1464784f85ba2eab9d54bffc53fb5430d1ed8b829/ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912", size = 11994292, upload-time = "2026-03-19T16:26:48.718Z" }, + { url = "https://files.pythonhosted.org/packages/64/22/fc4483871e767e5e95d1622ad83dad5ebb830f762ed0420fde7dfa9d9b08/ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036", size = 11398981, upload-time = "2026-03-19T16:26:54.513Z" }, + { url = "https://files.pythonhosted.org/packages/b0/99/66f0343176d5eab02c3f7fcd2de7a8e0dd7a41f0d982bee56cd1c24db62b/ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5", size = 11242422, upload-time = "2026-03-19T16:26:29.277Z" }, + { url = "https://files.pythonhosted.org/packages/5d/3a/a7060f145bfdcce4c987ea27788b30c60e2c81d6e9a65157ca8afe646328/ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12", size = 11232158, upload-time = "2026-03-19T16:26:42.321Z" }, + { url = "https://files.pythonhosted.org/packages/a7/53/90fbb9e08b29c048c403558d3cdd0adf2668b02ce9d50602452e187cd4af/ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c", size = 10577861, upload-time = "2026-03-19T16:26:57.459Z" }, + { url = "https://files.pythonhosted.org/packages/2f/aa/5f486226538fe4d0f0439e2da1716e1acf895e2a232b26f2459c55f8ddad/ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4", size = 10327310, upload-time = "2026-03-19T16:26:35.909Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/271afdffb81fe7bfc8c43ba079e9d96238f674380099457a74ccb3863857/ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d", size = 10840752, upload-time = "2026-03-19T16:26:45.723Z" }, + { url = "https://files.pythonhosted.org/packages/bf/29/a4ae78394f76c7759953c47884eb44de271b03a66634148d9f7d11e721bd/ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580", size = 11336961, upload-time = "2026-03-19T16:26:39.076Z" }, + { url = "https://files.pythonhosted.org/packages/26/6b/8786ba5736562220d588a2f6653e6c17e90c59ced34a2d7b512ef8956103/ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de", size = 10582538, upload-time = "2026-03-19T16:26:15.992Z" }, + { url = "https://files.pythonhosted.org/packages/2b/e9/346d4d3fffc6871125e877dae8d9a1966b254fbd92a50f8561078b88b099/ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1", size = 11755839, upload-time = "2026-03-19T16:26:19.897Z" }, + { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] [[package]] @@ -6105,14 +6040,14 @@ wheels = [ [[package]] name = "scipy-stubs" -version = "1.17.1.2" +version = "1.17.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "optype", extra = ["numpy"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c7/ab/43f681ffba42f363b7ed6b767fd215d1e26006578214ff8330586a11bf95/scipy_stubs-1.17.1.2.tar.gz", hash = "sha256:2ecadc8c87a3b61aaf7379d6d6b10f1038a829c53b9efe5b174fb97fc8b52237", size = 388354, upload-time = "2026-03-15T22:33:20.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/59/59c6cc3f9970154b9ed6b1aff42a0185cdd60cef54adc0404b9e77972221/scipy_stubs-1.17.1.3.tar.gz", hash = "sha256:5eb87a8d23d726706259b012ebe76a4a96a9ae9e141fc59bf55fc8eac2ed9e0f", size = 392185, upload-time = "2026-03-22T22:11:58.34Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/0b/ec4fe720c1202d9df729a3e9d9b7e4d2da9f6e7f28bd2877b7d0769f4f75/scipy_stubs-1.17.1.2-py3-none-any.whl", hash = "sha256:f19e8f5273dbe3b7ee6a9554678c3973b9695fa66b91f29206d00830a1536c06", size = 594377, upload-time = "2026-03-15T22:33:18.684Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d4/94304532c0a75a55526119043dd44a9bd1541a21e14483cbb54261c527d2/scipy_stubs-1.17.1.3-py3-none-any.whl", hash = "sha256:7b91d3f05aa47da06fbca14eb6c5bb4c28994e9245fd250cc847e375bab31297", size = 597933, upload-time = "2026-03-22T22:11:56.525Z" }, ] [[package]] @@ -6131,15 +6066,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.54.0" +version = "2.55.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c8/e9/2e3a46c304e7fa21eaa70612f60354e32699c7102eb961f67448e222ad7c/sentry_sdk-2.54.0.tar.gz", hash = "sha256:2620c2575128d009b11b20f7feb81e4e4e8ae08ec1d36cbc845705060b45cc1b", size = 413813, upload-time = "2026-03-02T15:12:41.355Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/b8/285293dc60fc198fffc3fcdbc7c6d4e646e0f74e61461c355d40faa64ceb/sentry_sdk-2.55.0.tar.gz", hash = "sha256:3774c4d8820720ca4101548131b9c162f4c9426eb7f4d24aca453012a7470f69", size = 424505, upload-time = "2026-03-17T14:15:51.707Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/53/39/be412cc86bc6247b8f69e9383d7950711bd86f8d0a4a4b0fe8fad685bc21/sentry_sdk-2.54.0-py2.py3-none-any.whl", hash = "sha256:fd74e0e281dcda63afff095d23ebcd6e97006102cdc8e78a29f19ecdf796a0de", size = 439198, upload-time = "2026-03-02T15:12:39.546Z" }, + { url = "https://files.pythonhosted.org/packages/9a/66/20465097782d7e1e742d846407ea7262d338c6e876ddddad38ca8907b38f/sentry_sdk-2.55.0-py2.py3-none-any.whl", hash = "sha256:97026981cb15699394474a196b88503a393cbc58d182ece0d3abe12b9bd978d4", size = 449284, upload-time = "2026-03-17T14:15:49.604Z" }, ] [package.optional-dependencies] @@ -6375,15 +6310,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.52.1" +version = "1.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/69/17425771797c36cded50b7fe44e850315d039f28b15901ab44839e70b593/starlette-1.0.0.tar.gz", hash = "sha256:6a4beaf1f81bb472fd19ea9b918b50dc3a77a6f2e190a12954b25e6ed5eea149", size = 2655289, upload-time = "2026-03-22T18:29:46.779Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, + { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] [[package]] @@ -6792,11 +6727,11 @@ wheels = [ [[package]] name = "types-cachetools" -version = "6.2.0.20251022" +version = "6.2.0.20260317" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/7f/16a4d8344c28193a5a74358028c2d2f753f0d9658dd98b9e1967c50045a2/types_cachetools-6.2.0.20260317.tar.gz", hash = "sha256:6d91855bcc944665897c125e720aa3c80aace929b77a64e796343701df4f61c6", size = 9812, upload-time = "2026-03-17T04:06:32.007Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" }, + { url = "https://files.pythonhosted.org/packages/17/9a/b00b23054934c4d569c19f7278c4fb32746cd36a64a175a216d3073a4713/types_cachetools-6.2.0.20260317-py3-none-any.whl", hash = "sha256:92fa9bc50e4629e31fca67ceb3fb1de71791e314fa16c0a0d2728724dc222c8b", size = 9346, upload-time = "2026-03-17T04:06:31.184Z" }, ] [[package]] @@ -6840,11 +6775,11 @@ wheels = [ [[package]] name = "types-docutils" -version = "0.22.3.20260316" +version = "0.22.3.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9f/27/a7f16b3a2fad0a4ddd85a668319f9a1d0311c4bd9578894f6471c7e6c788/types_docutils-0.22.3.20260316.tar.gz", hash = "sha256:8ef27d565b9831ff094fe2eac75337a74151013e2d21ecabd445c2955f891564", size = 57263, upload-time = "2026-03-16T04:29:12.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bb/243a87fc1605a4a94c2c343d6dbddbf0d7ef7c0b9550f360b8cda8e82c39/types_docutils-0.22.3.20260322.tar.gz", hash = "sha256:e2450bb997283c3141ec5db3e436b91f0aa26efe35eb9165178ca976ccb4930b", size = 57311, upload-time = "2026-03-22T04:08:44.064Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/60/c1f22b7cfc4837d5419e5a2d8702c7d65f03343f866364b71cccd8a73b79/types_docutils-0.22.3.20260316-py3-none-any.whl", hash = "sha256:083c7091b8072c242998ec51da1bf1492f0332387da81c3b085efbf5ca754c7d", size = 91968, upload-time = "2026-03-16T04:29:11.114Z" }, + { url = "https://files.pythonhosted.org/packages/c6/4a/22c090cd4615a16917dff817cbe7c5956da376c961e024c241cd962d2c3d/types_docutils-0.22.3.20260322-py3-none-any.whl", hash = "sha256:681d4510ce9b80a0c6a593f0f9843d81f8caa786db7b39ba04d9fd5480ac4442", size = 91978, upload-time = "2026-03-22T04:08:43.117Z" }, ] [[package]] @@ -6874,15 +6809,15 @@ wheels = [ [[package]] name = "types-gevent" -version = "25.9.0.20251228" +version = "25.9.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/85/c5043c4472f82c8ee3d9e0673eb4093c7d16770a26541a137a53a1d096f6/types_gevent-25.9.0.20251228.tar.gz", hash = "sha256:423ef9891d25c5a3af236c3e9aace4c444c86ff773fe13ef22731bc61d59abef", size = 38063, upload-time = "2025-12-28T03:28:28.651Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f0/14a99ddcaa69b559fa7cec8c9de880b792bebb0b848ae865d94ea9058533/types_gevent-25.9.0.20260322.tar.gz", hash = "sha256:91257920845762f09753c08aa20fad1743ac13d2de8bcf23f4b8fe967d803732", size = 38241, upload-time = "2026-03-22T04:08:55.213Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/b7/a2d6b652ab5a26318b68cafd58c46fafb9b15c5313d2d76a70b838febb4b/types_gevent-25.9.0.20251228-py3-none-any.whl", hash = "sha256:e2e225af4fface9241c16044983eb2fc3993f2d13d801f55c2932848649b7f2f", size = 55486, upload-time = "2025-12-28T03:28:27.382Z" }, + { url = "https://files.pythonhosted.org/packages/89/0f/964440b57eb4ddb4aca03479a4093852e1ce79010d1c5967234e6f5d6bd9/types_gevent-25.9.0.20260322-py3-none-any.whl", hash = "sha256:21b3c269b3a20ecb0e4668289c63b97d21694d84a004ab059c1e32ab970eacc2", size = 55500, upload-time = "2026-03-22T04:08:54.103Z" }, ] [[package]] @@ -6965,11 +6900,11 @@ wheels = [ [[package]] name = "types-openpyxl" -version = "3.1.5.20260316" +version = "3.1.5.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/38/32f8ee633dd66ca6d52b8853b9fd45dc3869490195a6ed435d5c868b9c2d/types_openpyxl-3.1.5.20260316.tar.gz", hash = "sha256:081dda9427ea1141e5649e3dcf630e7013a4cf254a5862a7e0a3f53c123b7ceb", size = 101318, upload-time = "2026-03-16T04:29:05.004Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/bf/15240de4d68192d2a1f385ef2f6f1ecb29b85d2f3791dd2e2d5b980be30f/types_openpyxl-3.1.5.20260322.tar.gz", hash = "sha256:a61d66ebe1e49697853c6db8e0929e1cda2c96755e71fb676ed7fc48dfdcf697", size = 101325, upload-time = "2026-03-22T04:08:40.426Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/df/b87ae6226ed7cc84b9e43119c489c7f053a9a25e209e0ebb5d84bc36fa37/types_openpyxl-3.1.5.20260316-py3-none-any.whl", hash = "sha256:38e7e125df520fb7eb72cb1129c9f024eb99ef9564aad2c27f68f080c26bcf2d", size = 166084, upload-time = "2026-03-16T04:29:03.657Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b4/c14191b30bcb266365b124b2bb4e67ecd68425a78ba77ee026f33667daa9/types_openpyxl-3.1.5.20260322-py3-none-any.whl", hash = "sha256:2f515f0b0bbfb04bfb587de34f7522d90b5151a8da7bbbd11ecec4ca40f64238", size = 166102, upload-time = "2026-03-22T04:08:39.174Z" }, ] [[package]] @@ -7044,11 +6979,11 @@ wheels = [ [[package]] name = "types-python-dateutil" -version = "2.9.0.20260305" +version = "2.9.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/c7/025c624f347e10476b439a6619a95f1d200250ea88e7ccea6e09e48a7544/types_python_dateutil-2.9.0.20260305.tar.gz", hash = "sha256:389717c9f64d8f769f36d55a01873915b37e97e52ce21928198d210fbd393c8b", size = 16885, upload-time = "2026-03-05T04:00:47.409Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/02/f72df9ef5ffc4f959b83cb80c8aa03eb8718a43e563ecd99ccffe265fa89/types_python_dateutil-2.9.0.20260323.tar.gz", hash = "sha256:a107aef5841db41ace381dbbbd7e4945220fc940f7a72172a0be5a92d9ab7164", size = 16897, upload-time = "2026-03-23T04:15:14.829Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/77/8c0d1ec97f0d9707ad3d8fa270ab8964e7b31b076d2f641c94987395cc75/types_python_dateutil-2.9.0.20260305-py3-none-any.whl", hash = "sha256:a3be9ca444d38cadabd756cfbb29780d8b338ae2a3020e73c266a83cc3025dd7", size = 18419, upload-time = "2026-03-05T04:00:46.392Z" }, + { url = "https://files.pythonhosted.org/packages/92/c1/b661838b97453e699a215451f2e22cee750eaaf4ea4619b34bdaf01221a4/types_python_dateutil-2.9.0.20260323-py3-none-any.whl", hash = "sha256:a23a50a07f6eb87e729d4cb0c2eb511c81761eeb3f505db2c1413be94aae8335", size = 18433, upload-time = "2026-03-23T04:15:13.683Z" }, ] [[package]] @@ -7062,11 +6997,11 @@ wheels = [ [[package]] name = "types-pywin32" -version = "311.0.0.20260316" +version = "311.0.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/17/a8/b4652002a854fcfe5d272872a0ae2d5df0e9dc482e1a6dfb5e97b905b76f/types_pywin32-311.0.0.20260316.tar.gz", hash = "sha256:c136fa489fe6279a13bca167b750414e18d657169b7cf398025856dc363004e8", size = 329956, upload-time = "2026-03-16T04:28:57.366Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/cc/f03ddb7412ac2fc2238358b617c2d5919ba96812dff8d3081f3b2754bb83/types_pywin32-311.0.0.20260323.tar.gz", hash = "sha256:2e8dc6a59fedccbc51b241651ce1e8aa58488934f517debf23a9c6d0ff329b4b", size = 332263, upload-time = "2026-03-23T04:15:20.004Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/83/704698d93788cf1c2f5e236eae2b37f1b2152ef84dc66b4b83f6c7487b76/types_pywin32-311.0.0.20260316-py3-none-any.whl", hash = "sha256:abb643d50012386d697af49384cc0e6e475eab76b0ca2a7f93d480d0862b3692", size = 392959, upload-time = "2026-03-16T04:28:56.104Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/d786d5d8b846e3cbe1ee52da8945560b111c789b42c3771b2129b312ab94/types_pywin32-311.0.0.20260323-py3-none-any.whl", hash = "sha256:2f2b03fc72ae77ccbb0ee258da0f181c3a38bd8602f6e332e42587b3b0d5f095", size = 395435, upload-time = "2026-03-23T04:15:18.76Z" }, ] [[package]] @@ -7162,16 +7097,16 @@ wheels = [ [[package]] name = "types-tensorflow" -version = "2.18.0.20260224" +version = "2.18.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "types-protobuf" }, { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/cb/4914c2fbc1cf8a8d1ef2a7c727bb6f694879be85edeee880a0c88e696af8/types_tensorflow-2.18.0.20260224.tar.gz", hash = "sha256:9b0ccc91c79c88791e43d3f80d6c879748fa0361409c5ff23c7ffe3709be00f2", size = 258786, upload-time = "2026-02-24T04:06:45.613Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/81dfaa2680031a6e087bcdfaf1c0556371098e229aee541e21c81a381065/types_tensorflow-2.18.0.20260322.tar.gz", hash = "sha256:135dc6ca06cc647a002e1bca5c5c99516fde51efd08e46c48a9b1916fc5df07f", size = 259030, upload-time = "2026-03-22T04:09:14.069Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/1d/a1c3c60f0eb1a204500dbdc66e3d18aafabc86ad07a8eca71ea05bc8c5a8/types_tensorflow-2.18.0.20260224-py3-none-any.whl", hash = "sha256:6a25f5f41f3e06f28c1f65c6e09f484d4ba0031d6d8df83a39df9d890245eefc", size = 329746, upload-time = "2026-02-24T04:06:44.4Z" }, + { url = "https://files.pythonhosted.org/packages/5b/0c/a178061450b640e53577e2c423ad22bf5d3f692f6bfeeb12156d02b531ef/types_tensorflow-2.18.0.20260322-py3-none-any.whl", hash = "sha256:d8776b6daacdb279e64f105f9dcbc0b8e3544b9a2f2eb71ec6ea5955081f65e6", size = 329771, upload-time = "2026-03-22T04:09:12.844Z" }, ] [[package]] @@ -7248,30 +7183,43 @@ wheels = [ [[package]] name = "ujson" -version = "5.9.0" +version = "5.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6e/54/6f2bdac7117e89a47de4511c9f01732a283457ab1bf856e1e51aa861619e/ujson-5.9.0.tar.gz", hash = "sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532", size = 7154214, upload-time = "2023-12-10T22:50:34.812Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/3e/c35530c5ffc25b71c59ae0cd7b8f99df37313daa162ce1e2f7925f7c2877/ujson-5.12.0.tar.gz", hash = "sha256:14b2e1eb528d77bc0f4c5bd1a7ebc05e02b5b41beefb7e8567c9675b8b13bcf4", size = 7158451, upload-time = "2026-03-11T22:19:30.397Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/ca/ae3a6ca5b4f82ce654d6ac3dde5e59520537e20939592061ba506f4e569a/ujson-5.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b23bbb46334ce51ddb5dded60c662fbf7bb74a37b8f87221c5b0fec1ec6454b", size = 57753, upload-time = "2023-12-10T22:49:03.939Z" }, - { url = "https://files.pythonhosted.org/packages/34/5f/c27fa9a1562c96d978c39852b48063c3ca480758f3088dcfc0f3b09f8e93/ujson-5.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6974b3a7c17bbf829e6c3bfdc5823c67922e44ff169851a755eab79a3dd31ec0", size = 54092, upload-time = "2023-12-10T22:49:05.194Z" }, - { url = "https://files.pythonhosted.org/packages/19/f3/1431713de9e5992e5e33ba459b4de28f83904233958855d27da820a101f9/ujson-5.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5964ea916edfe24af1f4cc68488448fbb1ec27a3ddcddc2b236da575c12c8ae", size = 51675, upload-time = "2023-12-10T22:49:06.449Z" }, - { url = "https://files.pythonhosted.org/packages/d3/93/de6fff3ae06351f3b1c372f675fe69bc180f93d237c9e496c05802173dd6/ujson-5.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba7cac47dd65ff88571eceeff48bf30ed5eb9c67b34b88cb22869b7aa19600d", size = 53246, upload-time = "2023-12-10T22:49:07.691Z" }, - { url = "https://files.pythonhosted.org/packages/26/73/db509fe1d7da62a15c0769c398cec66bdfc61a8bdffaf7dfa9d973e3d65c/ujson-5.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bbd91a151a8f3358c29355a491e915eb203f607267a25e6ab10531b3b157c5e", size = 58182, upload-time = "2023-12-10T22:49:08.89Z" }, - { url = "https://files.pythonhosted.org/packages/fc/a8/6be607fa3e1fa3e1c9b53f5de5acad33b073b6cc9145803e00bcafa729a8/ujson-5.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:829a69d451a49c0de14a9fecb2a2d544a9b2c884c2b542adb243b683a6f15908", size = 584493, upload-time = "2023-12-10T22:49:11.043Z" }, - { url = "https://files.pythonhosted.org/packages/c8/c7/33822c2f1a8175e841e2bc378ffb2c1109ce9280f14cedb1b2fa0caf3145/ujson-5.9.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a807ae73c46ad5db161a7e883eec0fbe1bebc6a54890152ccc63072c4884823b", size = 656038, upload-time = "2023-12-10T22:49:12.651Z" }, - { url = "https://files.pythonhosted.org/packages/51/b8/5309fbb299d5fcac12bbf3db20896db5178392904abe6b992da233dc69d6/ujson-5.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8fc2aa18b13d97b3c8ccecdf1a3c405f411a6e96adeee94233058c44ff92617d", size = 597643, upload-time = "2023-12-10T22:49:14.883Z" }, - { url = "https://files.pythonhosted.org/packages/5f/64/7b63043b95dd78feed401b9973958af62645a6d19b72b6e83d1ea5af07e0/ujson-5.9.0-cp311-cp311-win32.whl", hash = "sha256:70e06849dfeb2548be48fdd3ceb53300640bc8100c379d6e19d78045e9c26120", size = 38342, upload-time = "2023-12-10T22:49:16.854Z" }, - { url = "https://files.pythonhosted.org/packages/7a/13/a3cd1fc3a1126d30b558b6235c05e2d26eeaacba4979ee2fd2b5745c136d/ujson-5.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7309d063cd392811acc49b5016728a5e1b46ab9907d321ebbe1c2156bc3c0b99", size = 41923, upload-time = "2023-12-10T22:49:17.983Z" }, - { url = "https://files.pythonhosted.org/packages/16/7e/c37fca6cd924931fa62d615cdbf5921f34481085705271696eff38b38867/ujson-5.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:20509a8c9f775b3a511e308bbe0b72897ba6b800767a7c90c5cca59d20d7c42c", size = 57834, upload-time = "2023-12-10T22:49:19.799Z" }, - { url = "https://files.pythonhosted.org/packages/fb/44/2753e902ee19bf6ccaf0bda02f1f0037f92a9769a5d31319905e3de645b4/ujson-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b28407cfe315bd1b34f1ebe65d3bd735d6b36d409b334100be8cdffae2177b2f", size = 54119, upload-time = "2023-12-10T22:49:21.039Z" }, - { url = "https://files.pythonhosted.org/packages/d2/06/2317433e394450bc44afe32b6c39d5a51014da4c6f6cfc2ae7bf7b4a2922/ujson-5.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d302bd17989b6bd90d49bade66943c78f9e3670407dbc53ebcf61271cadc399", size = 51658, upload-time = "2023-12-10T22:49:22.494Z" }, - { url = "https://files.pythonhosted.org/packages/5b/3a/2acf0da085d96953580b46941504aa3c91a1dd38701b9e9bfa43e2803467/ujson-5.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f21315f51e0db8ee245e33a649dd2d9dce0594522de6f278d62f15f998e050e", size = 53370, upload-time = "2023-12-10T22:49:24.045Z" }, - { url = "https://files.pythonhosted.org/packages/03/32/737e6c4b1841720f88ae88ec91f582dc21174bd40742739e1fa16a0c9ffa/ujson-5.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5635b78b636a54a86fdbf6f027e461aa6c6b948363bdf8d4fbb56a42b7388320", size = 58278, upload-time = "2023-12-10T22:49:25.261Z" }, - { url = "https://files.pythonhosted.org/packages/8a/dc/3fda97f1ad070ccf2af597fb67dde358bc698ffecebe3bc77991d60e4fe5/ujson-5.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82b5a56609f1235d72835ee109163c7041b30920d70fe7dac9176c64df87c164", size = 584418, upload-time = "2023-12-10T22:49:27.573Z" }, - { url = "https://files.pythonhosted.org/packages/d7/57/e4083d774fcd8ff3089c0ff19c424abe33f23e72c6578a8172bf65131992/ujson-5.9.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5ca35f484622fd208f55041b042d9d94f3b2c9c5add4e9af5ee9946d2d30db01", size = 656126, upload-time = "2023-12-10T22:49:29.509Z" }, - { url = "https://files.pythonhosted.org/packages/0d/c3/8c6d5f6506ca9fcedd5a211e30a7d5ee053dc05caf23dae650e1f897effb/ujson-5.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:829b824953ebad76d46e4ae709e940bb229e8999e40881338b3cc94c771b876c", size = 597795, upload-time = "2023-12-10T22:49:31.029Z" }, - { url = "https://files.pythonhosted.org/packages/34/5a/a231f0cd305a34cf2d16930304132db3a7a8c3997b367dd38fc8f8dfae36/ujson-5.9.0-cp312-cp312-win32.whl", hash = "sha256:25fa46e4ff0a2deecbcf7100af3a5d70090b461906f2299506485ff31d9ec437", size = 38495, upload-time = "2023-12-10T22:49:33.2Z" }, - { url = "https://files.pythonhosted.org/packages/30/b7/18b841b44760ed298acdb150608dccdc045c41655e0bae4441f29bcab872/ujson-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:60718f1720a61560618eff3b56fd517d107518d3c0160ca7a5a66ac949c6cf1c", size = 42088, upload-time = "2023-12-10T22:49:34.921Z" }, + { url = "https://files.pythonhosted.org/packages/10/22/fd22e2f6766bae934d3050517ca47d463016bd8688508d1ecc1baa18a7ad/ujson-5.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58a11cb49482f1a095a2bd9a1d81dd7c8fb5d2357f959ece85db4e46a825fd00", size = 56139, upload-time = "2026-03-11T22:18:04.591Z" }, + { url = "https://files.pythonhosted.org/packages/c6/fd/6839adff4fc0164cbcecafa2857ba08a6eaeedd7e098d6713cb899a91383/ujson-5.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9b3cf13facf6f77c283af0e1713e5e8c47a0fe295af81326cb3cb4380212e797", size = 53836, upload-time = "2026-03-11T22:18:05.662Z" }, + { url = "https://files.pythonhosted.org/packages/f9/b0/0c19faac62d68ceeffa83a08dc3d71b8462cf5064d0e7e0b15ba19898dad/ujson-5.12.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb94245a715b4d6e24689de12772b85329a1f9946cbf6187923a64ecdea39e65", size = 57851, upload-time = "2026-03-11T22:18:06.744Z" }, + { url = "https://files.pythonhosted.org/packages/04/f6/e7fd283788de73b86e99e08256726bb385923249c21dcd306e59d532a1a1/ujson-5.12.0-cp311-cp311-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:0fe6b8b8968e11dd9b2348bd508f0f57cf49ab3512064b36bc4117328218718e", size = 59906, upload-time = "2026-03-11T22:18:07.791Z" }, + { url = "https://files.pythonhosted.org/packages/d7/3a/b100735a2b43ee6e8fe4c883768e362f53576f964d4ea841991060aeaf35/ujson-5.12.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89e302abd3749f6d6699691747969a5d85f7c73081d5ed7e2624c7bd9721a2ab", size = 57409, upload-time = "2026-03-11T22:18:08.79Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fa/f97cc20c99ca304662191b883ae13ae02912ca7244710016ba0cb8a5be34/ujson-5.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0727363b05ab05ee737a28f6200dc4078bce6b0508e10bd8aab507995a15df61", size = 1037339, upload-time = "2026-03-11T22:18:10.424Z" }, + { url = "https://files.pythonhosted.org/packages/10/7a/53ddeda0ffe1420db2f9999897b3cbb920fbcff1849d1f22b196d0f34785/ujson-5.12.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b62cb9a7501e1f5c9ffe190485501349c33e8862dde4377df774e40b8166871f", size = 1196625, upload-time = "2026-03-11T22:18:11.82Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1a/4c64a6bef522e9baf195dd5be151bc815cd4896c50c6e2489599edcda85f/ujson-5.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a6ec5bf6bc361f2f0f9644907a36ce527715b488988a8df534120e5c34eeda94", size = 1089669, upload-time = "2026-03-11T22:18:13.343Z" }, + { url = "https://files.pythonhosted.org/packages/18/11/8ccb109f5777ec0d9fb826695a9e2ac36ae94c1949fc8b1e4d23a5bd067a/ujson-5.12.0-cp311-cp311-win32.whl", hash = "sha256:006428d3813b87477d72d306c40c09f898a41b968e57b15a7d88454ecc42a3fb", size = 39648, upload-time = "2026-03-11T22:18:14.785Z" }, + { url = "https://files.pythonhosted.org/packages/6f/e3/87fc4c27b20d5125cff7ce52d17ea7698b22b74426da0df238e3efcb0cf2/ujson-5.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:40aa43a7a3a8d2f05e79900858053d697a88a605e3887be178b43acbcd781161", size = 43876, upload-time = "2026-03-11T22:18:15.768Z" }, + { url = "https://files.pythonhosted.org/packages/9e/21/324f0548a8c8c48e3e222eaed15fb6d48c796593002b206b4a28a89e445f/ujson-5.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:561f89cc82deeae82e37d4a4764184926fb432f740a9691563a391b13f7339a4", size = 38553, upload-time = "2026-03-11T22:18:17.251Z" }, + { url = "https://files.pythonhosted.org/packages/84/f6/ac763d2108d28f3a40bb3ae7d2fafab52ca31b36c2908a4ad02cd3ceba2a/ujson-5.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09b4beff9cc91d445d5818632907b85fb06943b61cb346919ce202668bf6794a", size = 56326, upload-time = "2026-03-11T22:18:18.467Z" }, + { url = "https://files.pythonhosted.org/packages/25/46/d0b3af64dcdc549f9996521c8be6d860ac843a18a190ffc8affeb7259687/ujson-5.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca0c7ce828bb76ab78b3991904b477c2fd0f711d7815c252d1ef28ff9450b052", size = 53910, upload-time = "2026-03-11T22:18:19.502Z" }, + { url = "https://files.pythonhosted.org/packages/9a/10/853c723bcabc3e9825a079019055fc99e71b85c6bae600607a2b9d31d18d/ujson-5.12.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d79c6635ccffcbfc1d5c045874ba36b594589be81d50d43472570bb8de9c57", size = 57754, upload-time = "2026-03-11T22:18:20.874Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c6/6e024830d988f521f144ead641981c1f7a82c17ad1927c22de3242565f5c/ujson-5.12.0-cp312-cp312-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:7e07f6f644d2c44d53b7a320a084eef98063651912c1b9449b5f45fcbdc6ccd2", size = 59936, upload-time = "2026-03-11T22:18:21.924Z" }, + { url = "https://files.pythonhosted.org/packages/34/c9/c5f236af5abe06b720b40b88819d00d10182d2247b1664e487b3ed9229cf/ujson-5.12.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:085b6ce182cdd6657481c7c4003a417e0655c4f6e58b76f26ee18f0ae21db827", size = 57463, upload-time = "2026-03-11T22:18:22.924Z" }, + { url = "https://files.pythonhosted.org/packages/ae/04/41342d9ef68e793a87d84e4531a150c2b682f3bcedfe59a7a5e3f73e9213/ujson-5.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16b4fe9c97dc605f5e1887a9e1224287291e35c56cbc379f8aa44b6b7bcfe2bb", size = 1037239, upload-time = "2026-03-11T22:18:24.04Z" }, + { url = "https://files.pythonhosted.org/packages/d4/81/dc2b7617d5812670d4ff4a42f6dd77926430ee52df0dedb2aec7990b2034/ujson-5.12.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0d2e8db5ade3736a163906154ca686203acc7d1d30736cbf577c730d13653d84", size = 1196713, upload-time = "2026-03-11T22:18:25.391Z" }, + { url = "https://files.pythonhosted.org/packages/b6/9c/80acff0504f92459ed69e80a176286e32ca0147ac6a8252cd0659aad3227/ujson-5.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:93bc91fdadcf046da37a214eaa714574e7e9b1913568e93bb09527b2ceb7f759", size = 1089742, upload-time = "2026-03-11T22:18:26.738Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f0/123ffaac17e45ef2b915e3e3303f8f4ea78bb8d42afad828844e08622b1e/ujson-5.12.0-cp312-cp312-win32.whl", hash = "sha256:2a248750abce1c76fbd11b2e1d88b95401e72819295c3b851ec73399d6849b3d", size = 39773, upload-time = "2026-03-11T22:18:28.244Z" }, + { url = "https://files.pythonhosted.org/packages/b5/20/f3bd2b069c242c2b22a69e033bfe224d1d15d3649e6cd7cc7085bb1412ff/ujson-5.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:1b5c6ceb65fecd28a1d20d1eba9dbfa992612b86594e4b6d47bb580d2dd6bcb3", size = 44040, upload-time = "2026-03-11T22:18:29.236Z" }, + { url = "https://files.pythonhosted.org/packages/f0/a7/01b5a0bcded14cd2522b218f2edc3533b0fcbccdea01f3e14a2b699071aa/ujson-5.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:9a5fcbe7b949f2e95c47ea8a80b410fcdf2da61c98553b45a4ee875580418b68", size = 38526, upload-time = "2026-03-11T22:18:30.551Z" }, + { url = "https://files.pythonhosted.org/packages/95/3c/5ee154d505d1aad2debc4ba38b1a60ae1949b26cdb5fa070e85e320d6b64/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_10_13_x86_64.whl", hash = "sha256:bf85a00ac3b56a1e7a19c5be7b02b5180a0895ac4d3c234d717a55e86960691c", size = 54494, upload-time = "2026-03-11T22:19:13.035Z" }, + { url = "https://files.pythonhosted.org/packages/ce/b3/9496ec399ec921e434a93b340bd5052999030b7ac364be4cbe5365ac6b20/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:64df53eef4ac857eb5816a56e2885ccf0d7dff6333c94065c93b39c51063e01d", size = 57999, upload-time = "2026-03-11T22:19:14.385Z" }, + { url = "https://files.pythonhosted.org/packages/0e/da/e9ae98133336e7c0d50b43626c3f2327937cecfa354d844e02ac17379ed1/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c0aed6a4439994c9666fb8a5b6c4eac94d4ef6ddc95f9b806a599ef83547e3b", size = 54518, upload-time = "2026-03-11T22:19:15.4Z" }, + { url = "https://files.pythonhosted.org/packages/58/10/978d89dded6bb1558cd46ba78f4351198bd2346db8a8ee1a94119022ce40/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efae5df7a8cc8bdb1037b0f786b044ce281081441df5418c3a0f0e1f86fe7bb3", size = 55736, upload-time = "2026-03-11T22:19:16.496Z" }, + { url = "https://files.pythonhosted.org/packages/80/25/1df8e6217c92e57a1266bf5be750b1dddc126ee96e53fe959d5693503bc6/ujson-5.12.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:8712b61eb1b74a4478cfd1c54f576056199e9f093659334aeb5c4a6b385338e5", size = 44615, upload-time = "2026-03-11T22:19:17.53Z" }, + { url = "https://files.pythonhosted.org/packages/19/fa/f4a957dddb99bd68c8be91928c0b6fefa7aa8aafc92c93f5d1e8b32f6702/ujson-5.12.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:871c0e5102e47995b0e37e8df7819a894a6c3da0d097545cd1f9f1f7d7079927", size = 52145, upload-time = "2026-03-11T22:19:18.566Z" }, + { url = "https://files.pythonhosted.org/packages/55/6e/50b5cf612de1ca06c7effdc5a5d7e815774dee85a5858f1882c425553b82/ujson-5.12.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:56ba3f7abbd6b0bb282a544dc38406d1a188d8bb9164f49fdb9c2fee62cb29da", size = 49577, upload-time = "2026-03-11T22:19:19.627Z" }, + { url = "https://files.pythonhosted.org/packages/6e/24/b6713fa9897774502cd4c2d6955bb4933349f7d84c3aa805531c382a4209/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c5a52987a990eb1bae55f9000994f1afdb0326c154fb089992f839ab3c30688", size = 50807, upload-time = "2026-03-11T22:19:20.778Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b6/c0e0f7901180ef80d16f3a4bccb5dc8b01515a717336a62928963a07b80b/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:adf28d13a33f9d750fe7a78fb481cac298fa257d8863d8727b2ea4455ea41235", size = 56972, upload-time = "2026-03-11T22:19:21.84Z" }, + { url = "https://files.pythonhosted.org/packages/02/a9/05d91b4295ea7239151eb08cf240e5a2ba969012fda50bc27bcb1ea9cd71/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51acc750ec7a2df786cdc868fb16fa04abd6269a01d58cf59bafc57978773d8e", size = 52045, upload-time = "2026-03-11T22:19:22.879Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7a/92047d32bf6f2d9db64605fc32e8eb0e0dd68b671eaafc12a464f69c4af4/ujson-5.12.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:ab9056d94e5db513d9313b34394f3a3b83e6301a581c28ad67773434f3faccab", size = 44053, upload-time = "2026-03-11T22:19:23.918Z" }, ] [[package]] diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000..54ac2a4b36 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,16 @@ +coverage: + status: + project: + default: + target: auto + +flags: + web: + paths: + - "web/" + carryforward: true + + api: + paths: + - "api/" + carryforward: true diff --git a/docker/.env.example b/docker/.env.example index 9d6cd65318..8cf77cf56b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -771,6 +771,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # VikingDB configurations, only available when VECTOR_STORE is `vikingdb` VIKINGDB_ACCESS_KEY=your-ak diff --git a/docker/dify-env-sync.py b/docker/dify-env-sync.py new file mode 100755 index 0000000000..d7c762748c --- /dev/null +++ b/docker/dify-env-sync.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 + +# ================================================================ +# Dify Environment Variables Synchronization Script +# +# Features: +# - Synchronize latest settings from .env.example to .env +# - Preserve custom settings in existing .env +# - Add new environment variables +# - Detect removed environment variables +# - Create backup files +# ================================================================ + +import argparse +import re +import shutil +import sys +from datetime import datetime +from pathlib import Path + +# ANSI color codes +RED = "\033[0;31m" +GREEN = "\033[0;32m" +YELLOW = "\033[1;33m" +BLUE = "\033[0;34m" +NC = "\033[0m" # No Color + + +def supports_color() -> bool: + """Return True if the terminal supports ANSI color codes.""" + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + + +def log_info(message: str) -> None: + """Print an informational message in blue.""" + if supports_color(): + print(f"{BLUE}[INFO]{NC} {message}") + else: + print(f"[INFO] {message}") + + +def log_success(message: str) -> None: + """Print a success message in green.""" + if supports_color(): + print(f"{GREEN}[SUCCESS]{NC} {message}") + else: + print(f"[SUCCESS] {message}") + + +def log_warning(message: str) -> None: + """Print a warning message in yellow to stderr.""" + if supports_color(): + print(f"{YELLOW}[WARNING]{NC} {message}", file=sys.stderr) + else: + print(f"[WARNING] {message}", file=sys.stderr) + + +def log_error(message: str) -> None: + """Print an error message in red to stderr.""" + if supports_color(): + print(f"{RED}[ERROR]{NC} {message}", file=sys.stderr) + else: + print(f"[ERROR] {message}", file=sys.stderr) + + +def parse_env_file(path: Path) -> dict[str, str]: + """Parse an .env-style file and return a mapping of key to raw value. + + Lines that are blank or start with '#' (after optional whitespace) are + skipped. Only lines containing '=' are considered variable definitions. + + Args: + path: Path to the .env file to parse. + + Returns: + Ordered dict mapping variable name to its value string. + """ + variables: dict[str, str] = {} + with path.open(encoding="utf-8") as fh: + for line in fh: + line = line.rstrip("\n") + # Skip blank lines and comment lines + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + if key: + variables[key] = value.strip() + return variables + + +def check_files(work_dir: Path) -> None: + """Verify required files exist; create .env from .env.example if absent. + + Args: + work_dir: Directory that must contain .env.example (and optionally .env). + + Raises: + SystemExit: If .env.example does not exist. + """ + log_info("Checking required files...") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + if not example_file.exists(): + log_error(".env.example file not found") + sys.exit(1) + + if not env_file.exists(): + log_warning(".env file does not exist. Creating from .env.example.") + shutil.copy2(example_file, env_file) + log_success(".env file created") + + log_success("Required files verified") + + +def create_backup(work_dir: Path) -> None: + """Create a timestamped backup of the current .env file. + + Backups are placed in ``/env-backup/`` with the filename + ``.env.backup_``. + + Args: + work_dir: Directory containing the .env file to back up. + """ + env_file = work_dir / ".env" + if not env_file.exists(): + return + + backup_dir = work_dir / "env-backup" + if not backup_dir.exists(): + backup_dir.mkdir(parents=True) + log_info(f"Created backup directory: {backup_dir}") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = backup_dir / f".env.backup_{timestamp}" + shutil.copy2(env_file, backup_file) + log_success(f"Backed up existing .env to {backup_file}") + + +def analyze_value_change(current: str, recommended: str) -> str | None: + """Analyse what kind of change occurred between two env values. + + Args: + current: Value currently set in .env. + recommended: Value present in .env.example. + + Returns: + A human-readable description string, or None when no analysis applies. + """ + use_colors = supports_color() + + def colorize(color: str, text: str) -> str: + return f"{color}{text}{NC}" if use_colors else text + + if not current and recommended: + return colorize(RED, " -> Setting from empty to recommended value") + if current and not recommended: + return colorize(RED, " -> Recommended value changed to empty") + + # Numeric comparison + if re.fullmatch(r"\d+", current) and re.fullmatch(r"\d+", recommended): + cur_int, rec_int = int(current), int(recommended) + if cur_int < rec_int: + return colorize(BLUE, f" -> Numeric increase ({current} < {recommended})") + if cur_int > rec_int: + return colorize(YELLOW, f" -> Numeric decrease ({current} > {recommended})") + return None + + # Boolean comparison + if current.lower() in {"true", "false"} and recommended.lower() in {"true", "false"}: + if current.lower() != recommended.lower(): + return colorize(BLUE, f" -> Boolean value change ({current} -> {recommended})") + return None + + # URL / endpoint + if current.startswith(("http://", "https://")) or recommended.startswith(("http://", "https://")): + return colorize(BLUE, " -> URL/endpoint change") + + # File path + if current.startswith("/") or recommended.startswith("/"): + return colorize(BLUE, " -> File path change") + + # String length + if len(current) != len(recommended): + return colorize(YELLOW, f" -> String length change ({len(current)} -> {len(recommended)} characters)") + + return None + + +def detect_differences(env_vars: dict[str, str], example_vars: dict[str, str]) -> dict[str, tuple[str, str]]: + """Find variables whose values differ between .env and .env.example. + + Only variables present in *both* files are compared; new or removed + variables are handled by separate functions. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Mapping of key -> (env_value, example_value) for every key whose + values differ. + """ + log_info("Detecting differences between .env and .env.example...") + + diffs: dict[str, tuple[str, str]] = {} + for key, example_value in example_vars.items(): + if key in env_vars and env_vars[key] != example_value: + diffs[key] = (env_vars[key], example_value) + + if diffs: + log_success(f"Detected differences in {len(diffs)} environment variables") + show_differences_detail(diffs) + else: + log_info("No differences detected") + + return diffs + + +def show_differences_detail(diffs: dict[str, tuple[str, str]]) -> None: + """Print a formatted table of differing environment variables. + + Args: + diffs: Mapping of key -> (current_value, recommended_value). + """ + use_colors = supports_color() + + log_info("") + log_info("=== Environment Variable Differences ===") + + if not diffs: + log_info("No differences to display") + return + + for count, (key, (env_value, example_value)) in enumerate(diffs.items(), start=1): + print() + if use_colors: + print(f"{YELLOW}[{count}] {key}{NC}") + print(f" {GREEN}.env (current){NC} : {env_value}") + print(f" {BLUE}.env.example (recommended){NC} : {example_value}") + else: + print(f"[{count}] {key}") + print(f" .env (current) : {env_value}") + print(f" .env.example (recommended) : {example_value}") + + analysis = analyze_value_change(env_value, example_value) + if analysis: + print(analysis) + + print() + log_info("=== Difference Analysis Complete ===") + log_info("Note: Consider changing to the recommended values above.") + log_info("Current implementation preserves .env values.") + print() + + +def detect_removed_variables(env_vars: dict[str, str], example_vars: dict[str, str]) -> list[str]: + """Identify variables present in .env but absent from .env.example. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Sorted list of variable names that no longer appear in .env.example. + """ + log_info("Detecting removed environment variables...") + + removed = sorted(set(env_vars) - set(example_vars)) + + if removed: + log_warning("The following environment variables have been removed from .env.example:") + for var in removed: + log_warning(f" - {var}") + log_warning("Consider manually removing these variables from .env") + else: + log_success("No removed environment variables found") + + return removed + + +def sync_env_file(work_dir: Path, env_vars: dict[str, str], diffs: dict[str, tuple[str, str]]) -> None: + """Rewrite .env based on .env.example while preserving custom values. + + The output file follows the exact line structure of .env.example + (preserving comments, blank lines, and ordering). For every variable + that exists in .env with a different value from the example, the + current .env value is kept. Variables that are new in .env.example + (not present in .env at all) are added with the example's default. + + Args: + work_dir: Directory containing .env and .env.example. + env_vars: Parsed key/value pairs from the original .env. + diffs: Keys whose .env values differ from .env.example (to preserve). + """ + log_info("Starting partial synchronization of .env file...") + + example_file = work_dir / ".env.example" + new_env_file = work_dir / ".env.new" + + # Keys whose current .env value should override the example default + preserved_keys: set[str] = set(diffs.keys()) + + preserved_count = 0 + updated_count = 0 + + env_var_pattern = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=") + + with example_file.open(encoding="utf-8") as src, new_env_file.open("w", encoding="utf-8") as dst: + for line in src: + raw_line = line.rstrip("\n") + match = env_var_pattern.match(raw_line) + if match: + key = match.group(1) + if key in preserved_keys: + # Write the preserved value from .env + dst.write(f"{key}={env_vars[key]}\n") + log_info(f" Preserved: {key} (.env value)") + preserved_count += 1 + else: + # Use the example value (covers new vars and unchanged ones) + dst.write(line if line.endswith("\n") else raw_line + "\n") + updated_count += 1 + else: + # Blank line, comment, or non-variable line — keep as-is + dst.write(line if line.endswith("\n") else raw_line + "\n") + + # Atomically replace the original .env + try: + new_env_file.replace(work_dir / ".env") + except OSError as exc: + log_error(f"Failed to replace .env file: {exc}") + new_env_file.unlink(missing_ok=True) + sys.exit(1) + + log_success("Successfully created new .env file") + log_success("Partial synchronization of .env file completed") + log_info(f" Preserved .env values: {preserved_count}") + log_info(f" Updated to .env.example values: {updated_count}") + + +def show_statistics(work_dir: Path) -> None: + """Print a summary of variable counts from both env files. + + Args: + work_dir: Directory containing .env and .env.example. + """ + log_info("Synchronization statistics:") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + example_count = len(parse_env_file(example_file)) if example_file.exists() else 0 + env_count = len(parse_env_file(env_file)) if env_file.exists() else 0 + + log_info(f" .env.example environment variables: {example_count}") + log_info(f" .env environment variables: {env_count}") + + +def build_arg_parser() -> argparse.ArgumentParser: + """Build and return the CLI argument parser. + + Returns: + Configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser( + prog="dify-env-sync", + description=( + "Synchronize .env with .env.example: add new variables, " + "preserve custom values, and report removed variables." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Examples:\n" + " # Run from the docker/ directory (default)\n" + " python dify-env-sync.py\n\n" + " # Specify a custom working directory\n" + " python dify-env-sync.py --dir /path/to/docker\n" + ), + ) + parser.add_argument( + "--dir", + metavar="DIRECTORY", + default=".", + help="Working directory containing .env and .env.example (default: current directory)", + ) + parser.add_argument( + "--no-backup", + action="store_true", + default=False, + help="Skip creating a timestamped backup of the existing .env file", + ) + return parser + + +def main() -> None: + """Orchestrate the complete environment variable synchronization process.""" + parser = build_arg_parser() + args = parser.parse_args() + + work_dir = Path(args.dir).resolve() + + log_info("=== Dify Environment Variables Synchronization Script ===") + log_info(f"Execution started: {datetime.now()}") + log_info(f"Working directory: {work_dir}") + + # 1. Verify prerequisites + check_files(work_dir) + + # 2. Backup existing .env + if not args.no_backup: + create_backup(work_dir) + + # 3. Parse both files + env_vars = parse_env_file(work_dir / ".env") + example_vars = parse_env_file(work_dir / ".env.example") + + # 4. Report differences (values that changed in the example) + diffs = detect_differences(env_vars, example_vars) + + # 5. Report variables removed from the example + detect_removed_variables(env_vars, example_vars) + + # 6. Rewrite .env + sync_env_file(work_dir, env_vars, diffs) + + # 7. Print summary statistics + show_statistics(work_dir) + + log_success("=== Synchronization process completed successfully ===") + log_info(f"Execution finished: {datetime.now()}") + + +if __name__ == "__main__": + main() diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 939f23136a..04bd2858ff 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.1 + image: langgenius/dify-web:1.13.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b6b6f299cf..6e11cac678 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -345,6 +345,9 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER} BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT:-500} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO:-0.05} + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: ${BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS:-300} VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} @@ -728,7 +731,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -770,7 +773,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -809,7 +812,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.1 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -839,7 +842,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.1 + image: langgenius/dify-web:1.13.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 256e669c8d..fbe9ebc448 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -28,6 +28,7 @@ http_access deny manager http_access allow localhost include /etc/squid/conf.d/*.conf http_access deny all +tcp_outgoing_address 0.0.0.0 ################################## Proxy Server ################################ http_port ${HTTP_PORT} diff --git a/docs/eu-ai-act-compliance.md b/docs/eu-ai-act-compliance.md new file mode 100644 index 0000000000..5fa29eed3f --- /dev/null +++ b/docs/eu-ai-act-compliance.md @@ -0,0 +1,186 @@ +# EU AI Act Compliance Guide for Dify Deployers + +Dify is an LLMOps platform for building RAG pipelines, agents, and AI workflows. If you deploy Dify in the EU — whether self-hosted or using a cloud provider — the EU AI Act applies to your deployment. This guide covers what the regulation requires and how Dify's architecture maps to those requirements. + +## Is your system in scope? + +The detailed obligations in Articles 12, 13, and 14 only apply to **high-risk AI systems** as defined in Annex III of the EU AI Act. A Dify application is high-risk if it is used for: + +- **Recruitment and HR** — screening candidates, evaluating employee performance, allocating tasks +- **Credit scoring and insurance** — assessing creditworthiness or setting premiums +- **Law enforcement** — profiling, criminal risk assessment, border control +- **Critical infrastructure** — managing energy, water, transport, or telecommunications systems +- **Education assessment** — grading students, determining admissions +- **Essential public services** — evaluating eligibility for benefits, housing, or emergency services + +Most Dify deployments (customer-facing chatbots, internal knowledge bases, content generation workflows) are **not** high-risk. If your Dify application does not fall into one of the categories above: + +- **Article 50** (end-user transparency) still applies if users interact with your application directly. See the [Article 50 section](#article-50-end-user-transparency) below. +- **GDPR** still applies if you process personal data. See the [GDPR section](#gdpr-considerations) below. +- The high-risk obligations (Articles 9-15) are less likely to apply, but risk classification is context-dependent. **Do not self-classify without legal review.** Focus on Article 50 (transparency) and GDPR (data protection) as your baseline obligations. + +If you are unsure whether your use case qualifies as high-risk, consult a qualified legal professional before proceeding. + +## Self-hosted vs cloud: different compliance profiles + +| Deployment | Your role | Dify's role | Who handles compliance? | +|-----------|----------|-------------|------------------------| +| **Self-hosted** | Provider and deployer | Framework provider — obligations under Article 25 apply only if Dify is placed on the market or put into service as part of a complete AI system bearing its name or trademark | You | +| **Dify Cloud** | Deployer | Provider and processor | Shared — Dify handles SOC 2 and GDPR for the platform; you handle AI Act obligations for your specific use case | + +Dify Cloud already has SOC 2 Type II and GDPR compliance for the platform itself. But the EU AI Act adds obligations specific to AI systems that SOC 2 does not cover: risk classification, technical documentation, transparency, and human oversight. + +## Supported providers and services + +Dify integrates with a broad range of AI providers and data stores. The following are the key ones relevant to compliance: + +- **AI providers:** HuggingFace (core), plus integrations with OpenAI, Anthropic, Google, and 100+ models via provider plugins +- **Model identifiers include:** gpt-4o, gpt-3.5-turbo, claude-3-opus, gemini-2.5-flash, whisper-1, and others +- **Vector database connections:** Extensive RAG infrastructure supporting numerous vector stores + +Dify's plugin architecture means actual provider usage depends on your configuration. Document which providers and models are active in your deployment. + +## Data flow diagram + +A typical Dify RAG deployment: + +```mermaid +graph LR + USER((User)) -->|query| DIFY[Dify Platform] + DIFY -->|prompts| LLM([LLM Provider]) + LLM -->|responses| DIFY + DIFY -->|documents| EMBED([Embedding Model]) + EMBED -->|vectors| DIFY + DIFY -->|store/retrieve| VS[(Vector Store)] + DIFY -->|knowledge| KB[(Knowledge Base)] + DIFY -->|response| USER + + classDef processor fill:#60a5fa,stroke:#1e40af,color:#000 + classDef controller fill:#4ade80,stroke:#166534,color:#000 + classDef app fill:#a78bfa,stroke:#5b21b6,color:#000 + classDef user fill:#f472b6,stroke:#be185d,color:#000 + + class USER user + class DIFY app + class LLM processor + class EMBED processor + class VS controller + class KB controller +``` + +**GDPR roles** (providers are typically processors for customer-submitted data, but the exact role depends on each provider's terms of service and processing purpose; deployers should review each provider's DPA): +- **Cloud LLM providers (OpenAI, Anthropic, Google)** typically act as processors — requires DPA. +- **Cloud embedding services** typically act as processors — requires DPA. +- **Self-hosted vector stores (Weaviate, Qdrant, pgvector):** Your organization remains the controller — no third-party transfer. +- **Cloud vector stores (Pinecone, Zilliz Cloud)** typically act as processors — requires DPA. +- **Knowledge base documents:** Your organization is the controller — stored in your infrastructure. + +## Article 11: Technical documentation + +High-risk systems need Annex IV documentation. For Dify deployments, key sections include: + +| Section | What Dify provides | What you must document | +|---------|-------------------|----------------------| +| General description | Platform capabilities, supported models | Your specific use case, intended users, deployment context | +| Development process | Dify's architecture, plugin system | Your RAG pipeline design, prompt engineering, knowledge base curation | +| Monitoring | Dify's built-in logging and analytics | Your monitoring plan, alert thresholds, incident response | +| Performance metrics | Dify's evaluation features | Your accuracy benchmarks, quality thresholds, bias testing | +| Risk management | — | Risk assessment for your specific use case | + +Some sections can be derived from Dify's architecture and your deployment configuration, as shown in the table above. The remaining sections require your input. + +## Article 12: Record-keeping + +Dify's built-in logging covers several Article 12 requirements: + +| Requirement | Dify Feature | Status | +|------------|-------------|--------| +| Conversation logs | Full conversation history with timestamps | **Covered** | +| Model tracking | Model name recorded per interaction | **Covered** | +| Token usage | Token counts per message | **Covered** | +| Cost tracking | Cost per conversation (if provider reports it) | **Partial** | +| Document retrieval | RAG source documents logged | **Covered** | +| User identification | User session tracking | **Covered** | +| Error logging | Failed generation logs | **Covered** | +| Data retention | Configurable | **Your responsibility** | + +**Retention periods:** The required retention period depends on your role under the Act. Article 18 requires **providers** of high-risk systems to retain logs and technical documentation for **10 years** after market placement. Article 26(6) requires **deployers** to retain logs for at least **6 months**. If you self-host Dify and have substantially modified the system, you may be classified as a provider rather than a deployer. Confirm the applicable retention period with legal counsel. + +## Article 13: Transparency to deployers + +Article 13 requires providers of high-risk AI systems to supply deployers with the information needed to understand and operate the system correctly. This is a **documentation obligation**, not a logging obligation. For Dify deployments, this means the upstream LLM and embedding providers must give you: + +- Instructions for use, including intended purpose and known limitations +- Accuracy metrics and performance benchmarks +- Known or foreseeable risks and residual risks after mitigation +- Technical specifications: input/output formats, training data characteristics, model architecture details + +As a deployer, collect model cards, system documentation, and accuracy reports from each AI provider your Dify application uses. Maintain these as part of your Annex IV technical documentation. + +Dify's platform features provide **supporting evidence** that can inform Article 13 documentation, but they do not satisfy Article 13 on their own: +- **Source attribution** — Dify's RAG citation feature shows which documents informed the response, supporting deployer-side auditing +- **Model identification** — Dify logs which LLM model generates responses, providing evidence for system documentation +- **Conversation logs** — execution history helps compile performance and behavior evidence + +You must independently produce system documentation covering how your specific Dify deployment uses AI, its intended purpose, performance characteristics, and residual risks. + +## Article 50: End-user transparency + +Article 50 requires deployers to inform end users that they are interacting with an AI system. This is a separate obligation from Article 13 and applies even to limited-risk systems. + +For Dify applications serving end users: + +1. **Disclose AI involvement** — tell users they are interacting with an AI system +2. **AI-generated content labeling** — identify AI-generated content as such (e.g., clear labeling in the UI) + +Dify's "citation" feature also supports end-user transparency by showing users which knowledge base documents informed the answer. + +> **Note:** Article 50 applies to chatbots and systems interacting directly with natural persons. It has a separate scope from the high-risk designation under Annex III — it applies even to limited-risk systems. + +## Article 14: Human oversight + +Article 14 requires that high-risk AI systems be designed so that natural persons can effectively oversee them. Dify provides **automated technical safeguards** that support human oversight, but they are not a substitute for it: + +| Dify Feature | What It Does | Oversight Role | +|-------------|-------------|----------------| +| Annotation/feedback system | Human review of AI outputs | **Direct oversight** — humans evaluate and correct AI responses | +| Content moderation | Built-in filtering before responses reach users | **Automated safeguard** — reduces harmful outputs but does not replace human judgment on edge cases | +| Rate limiting | Controls on API usage | **Automated safeguard** — bounds system behavior, supports overseer's ability to maintain control | +| Workflow control | Insert human review steps between AI generation and output | **Oversight enabler** — allows building approval gates into the pipeline | + +These automated controls are necessary building blocks, but Article 14 compliance requires **human oversight procedures** on top of them: +- **Escalation procedures** — define what happens when moderation triggers or edge cases arise (who is notified, what action is taken) +- **Human review pipeline** — for high-stakes decisions, route AI outputs to a qualified person before they take effect +- **Override mechanism** — a human must be able to halt AI responses or override the system's output +- **Competence requirements** — the human overseer must understand the system's capabilities, limitations, and the context of its outputs + +### Recommended pattern + +For high-risk use cases (HR, legal, medical), configure your Dify workflow to require human approval before the AI response is delivered to the end user or acted upon. + +## Knowledge base compliance + +Dify's knowledge base feature has specific compliance implications: + +1. **Data provenance:** Document where your knowledge base documents come from. Article 10 requires data governance for training data; knowledge bases are analogous. +2. **Update tracking:** When you add, remove, or update documents in the knowledge base, log the change. The AI system's behavior changes with its knowledge base. +3. **PII in documents:** If knowledge base documents contain personal data, GDPR applies to the entire RAG pipeline. Implement access controls and consider PII redaction before indexing. +4. **Copyright:** Ensure you have the right to use the documents in your knowledge base for AI-assisted generation. + +## GDPR considerations + +1. **Legal basis** (Article 6): Document why AI processing of user queries is necessary +2. **Data Processing Agreements** (Article 28): Required for each cloud LLM and embedding provider +3. **Data minimization:** Only include necessary context in prompts; avoid sending entire documents when a relevant excerpt suffices +4. **Right to erasure:** If a user requests deletion, ensure their conversations are removed from Dify's logs AND any vector store entries derived from their data +5. **Cross-border transfers:** Providers based outside the EEA — including US-based providers (OpenAI, Anthropic), and any other non-EEA providers you route to — require Standard Contractual Clauses (SCCs) or equivalent safeguards under Chapter V of the GDPR. Review each provider's transfer mechanism individually. + +## Resources + +- [EU AI Act full text](https://artificialintelligenceact.eu/) +- [Dify documentation](https://docs.dify.ai/) +- [Dify SOC 2 compliance](https://dify.ai/trust) + +--- + +*This is not legal advice. Consult a qualified professional for compliance decisions.* diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 7c8a293446..728aa0d054 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -69,6 +69,7 @@ }, "pnpm": { "overrides": { + "flatted@<=3.4.1": "3.4.2", "rollup@>=4.0.0,<4.59.0": "4.59.0" } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index b0aee38cdf..c9081420f5 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -5,6 +5,7 @@ settings: excludeLinksFromLockfile: false overrides: + flatted@<=3.4.1: 3.4.2 rollup@>=4.0.0,<4.59.0: 4.59.0 importers: @@ -754,8 +755,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.4.1: - resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} + flatted@3.4.2: + resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==} follow-redirects@1.15.11: resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -1849,10 +1850,10 @@ snapshots: flat-cache@4.0.1: dependencies: - flatted: 3.4.1 + flatted: 3.4.2 keyv: 4.5.4 - flatted@3.4.1: {} + flatted@3.4.2: {} follow-redirects@1.15.11: {} diff --git a/web/.env.example b/web/.env.example index ed06ebe2c9..079c3bdeef 100644 --- a/web/.env.example +++ b/web/.env.example @@ -6,19 +6,23 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED NEXT_PUBLIC_BASE_PATH= # The base URL of console application, refers to the Console base URL of WEB service if console domain is # different from api or web app domain. -# example: http://cloud.dify.ai/console/api +# example: https://cloud.dify.ai/console/api NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api # The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from # console or api domain. -# example: http://udify.app/api +# example: https://udify.app/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api -# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly. +# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. +NEXT_PUBLIC_COOKIE_DOMAIN= + +# Dev-only Hono proxy targets. +# The frontend keeps requesting http://localhost:5001 directly, +# the proxy server will forward the request to the target server, +# so that you don't need to run a separate backend server and use online API in development. HONO_PROXY_HOST=127.0.0.1 HONO_PROXY_PORT=5001 HONO_CONSOLE_API_PROXY_TARGET= HONO_PUBLIC_API_PROXY_TARGET= -# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. -NEXT_PUBLIC_COOKIE_DOMAIN= # The API PREFIX for MARKETPLACE NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 diff --git a/web/README.md b/web/README.md index 1e57e7c6a9..14ca856875 100644 --- a/web/README.md +++ b/web/README.md @@ -1,6 +1,6 @@ # Dify Frontend -This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). +This is a [Next.js] project, but you can dev with [vinext]. ## Getting Started @@ -8,8 +8,11 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) -- [pnpm](https://pnpm.io) +- [Node.js] +- [pnpm] + +You can also use [Vite+] with the corresponding `vp` commands. +For example, use `vp install` instead of `pnpm install` and `vp test` instead of `pnpm run test`. > [!TIP] > It is recommended to install and enable Corepack to manage package manager versions automatically: @@ -19,7 +22,7 @@ Before starting the web frontend service, please make sure the following environ > corepack enable > ``` > -> Learn more: [Corepack](https://github.com/nodejs/corepack#readme) +> Learn more: [Corepack] First, install the dependencies: @@ -27,31 +30,14 @@ First, install the dependencies: pnpm install ``` -Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements: +Then, configure the environment variables. +Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. +Modify the values of these environment variables according to your requirements: ```bash cp .env.example .env.local ``` -```txt -# For production release, change this to PRODUCTION -NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT -# The deployment edition, SELF_HOSTED -NEXT_PUBLIC_EDITION=SELF_HOSTED -# The base URL of console application, refers to the Console base URL of WEB service if console domain is -# different from api or web app domain. -# example: http://cloud.dify.ai/console/api -NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api -NEXT_PUBLIC_COOKIE_DOMAIN= -# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from -# console or api domain. -# example: http://udify.app/api -NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api - -# SENTRY -NEXT_PUBLIC_SENTRY_DSN= -``` - > [!IMPORTANT] > > 1. When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. The frontend and backend must be under the same top-level domain in order to share authentication cookies. @@ -61,11 +47,16 @@ Finally, run the development server: ```bash pnpm run dev +# or if you are using vinext which provides a better development experience +pnpm run dev:vinext +# (optional) start the dev proxy server so that you can use online API in development +pnpm run dev:proxy ``` -Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. +Open with your browser to see the result. -You can start editing the file under folder `app`. The page auto-updates as you edit the file. +You can start editing the file under folder `app`. +The page auto-updates as you edit the file. ## Deploy @@ -91,7 +82,7 @@ pnpm run start --port=3001 --host=0.0.0.0 ## Storybook -This project uses [Storybook](https://storybook.js.org/) for UI component development. +This project uses [Storybook] for UI component development. To start the storybook server, run: @@ -99,19 +90,24 @@ To start the storybook server, run: pnpm storybook ``` -Open [http://localhost:6006](http://localhost:6006) with your browser to see the result. +Open with your browser to see the result. ## Lint Code If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. -Then follow the [Lint Documentation](./docs/lint.md) to lint the code. +Then follow the [Lint Documentation] to lint the code. ## Test -We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. +We use [Vitest] and [React Testing Library] for Unit Testing. -**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples. +**📖 Complete Testing Guide**: See [web/docs/test.md] for detailed testing specifications, best practices, and examples. + +> [!IMPORTANT] +> As we are using Vite+, the `vitest` command is not available. +> Please make sure to run tests with `vp` commands. +> For example, use `npx vp test` instead of `npx vitest`. Run test: @@ -119,12 +115,17 @@ Run test: pnpm test ``` +> [!NOTE] +> Our test is not fully stable yet, and we are actively working on improving it. +> If you encounter test failures only in CI but not locally, please feel free to ignore them and report the issue to us. +> You can try to re-run the test in CI, and it may pass successfully. + ### Example Code If you are not familiar with writing tests, refer to: -- [classnames.spec.ts](./utils/classnames.spec.ts) - Utility function test example -- [index.spec.tsx](./app/components/base/button/index.spec.tsx) - Component test example +- [classnames.spec.ts] - Utility function test example +- [index.spec.tsx] - Component test example ### Analyze Component Complexity @@ -134,7 +135,7 @@ Before writing tests, use the script to analyze component complexity: pnpm analyze-component app/components/your-component/index.tsx ``` -This will help you determine the testing strategy. See [web/testing/testing.md](./testing/testing.md) for details. +This will help you determine the testing strategy. See [web/testing/testing.md] for details. ## Documentation @@ -142,4 +143,19 @@ Visit to view the full documentation. ## Community -The Dify community can be found on [Discord community](https://discord.gg/5AEfbxcd9k), where you can ask questions, voice ideas, and share your projects. +The Dify community can be found on [Discord community], where you can ask questions, voice ideas, and share your projects. + +[Corepack]: https://github.com/nodejs/corepack#readme +[Discord community]: https://discord.gg/5AEfbxcd9k +[Lint Documentation]: ./docs/lint.md +[Next.js]: https://nextjs.org +[Node.js]: https://nodejs.org +[React Testing Library]: https://testing-library.com/docs/react-testing-library/intro +[Storybook]: https://storybook.js.org +[Vite+]: https://viteplus.dev +[Vitest]: https://vitest.dev +[classnames.spec.ts]: ./utils/classnames.spec.ts +[index.spec.tsx]: ./app/components/base/button/index.spec.tsx +[pnpm]: https://pnpm.io +[vinext]: https://github.com/cloudflare/vinext +[web/docs/test.md]: ./docs/test.md diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index c3e8410955..c5766878a1 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -29,7 +29,7 @@ const mockOnPlanInfoChanged = vi.fn() const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) let mockDeleteMutationPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), @@ -57,7 +57,7 @@ vi.mock('@headlessui/react', async () => { } }) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 079f667dbc..1be7e56086 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -38,7 +38,7 @@ let mockShowTagManagementModal = false const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -46,7 +46,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (_loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComponent = (props: Record) => { return
diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 4ac9824ddd..bc1f7a3a06 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -35,7 +35,7 @@ const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() const mockOnPlanInfoChanged = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -117,7 +117,7 @@ vi.mock('ahooks', async () => { }) // Mock dynamically loaded modals with test stubs -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 4891760df4..64d358cbe6 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -64,7 +64,7 @@ vi.mock('@/service/use-education', () => ({ // ─── Navigation mocks ─────────────────────────────────────────────────────── const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx index e01d9250fd..0c1efbe1af 100644 --- a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx +++ b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx @@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type' import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { ALL_PLANS } from '@/app/components/billing/config' import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher' import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item' @@ -21,7 +22,6 @@ let mockAppCtx: Record = {} const mockFetchSubscriptionUrls = vi.fn() const mockInvoices = vi.fn() const mockOpenAsyncWindow = vi.fn() -const mockToastNotify = vi.fn() // ─── Context mocks ─────────────────────────────────────────────────────────── vi.mock('@/context/app-context', () => ({ @@ -49,12 +49,8 @@ vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: () => mockOpenAsyncWindow, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -82,12 +78,15 @@ const renderCloudPlanItem = ({ canPay = true, }: RenderCloudPlanItemOptions = {}) => { return render( - , + <> + + + , ) } @@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.dismiss() setupAppContext() mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' }) mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' }) @@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => { await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should not proceed with payment expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled() diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 8c35cd9a8c..707f1d690a 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -63,7 +63,7 @@ vi.mock('@/service/use-billing', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/partner-stack-flow.test.tsx b/web/__tests__/billing/partner-stack-flow.test.tsx index 4f265478cd..fe642ac70b 100644 --- a/web/__tests__/billing/partner-stack-flow.test.tsx +++ b/web/__tests__/billing/partner-stack-flow.test.tsx @@ -18,7 +18,7 @@ let mockSearchParams = new URLSearchParams() const mockMutateAsync = vi.fn() // ─── Module mocks ──────────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => mockSearchParams, useRouter: () => ({ push: vi.fn() }), usePathname: () => '/', diff --git a/web/__tests__/billing/pricing-modal-flow.test.tsx b/web/__tests__/billing/pricing-modal-flow.test.tsx index 7326ee3559..2ec7298618 100644 --- a/web/__tests__/billing/pricing-modal-flow.test.tsx +++ b/web/__tests__/billing/pricing-modal-flow.test.tsx @@ -51,7 +51,7 @@ vi.mock('@/hooks/use-async-window-open', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/self-hosted-plan-flow.test.tsx b/web/__tests__/billing/self-hosted-plan-flow.test.tsx index 810d36da8a..a3386d0092 100644 --- a/web/__tests__/billing/self-hosted-plan-flow.test.tsx +++ b/web/__tests__/billing/self-hosted-plan-flow.test.tsx @@ -10,12 +10,12 @@ import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config' import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item' import { SelfHostedPlan } from '@/app/components/billing/type' let mockAppCtx: Record = {} -const mockToastNotify = vi.fn() const originalLocation = window.location let assignedHref = '' @@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({ AwsMarketplaceDark: () => , })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({ default: ({ plan }: { plan: string }) => (
Features
@@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record = {}) => { } } +const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => { + return render( + <> + + + , + ) +} + describe('Self-Hosted Plan Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.dismiss() setupAppContext() // Mock window.location with minimal getter/setter (Location props are non-enumerable) @@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => { // ─── 1. Plan Rendering ────────────────────────────────────────────────── describe('Plan rendering', () => { it('should render community plan with name and description', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument() expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument() }) it('should render premium plan with cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument() expect(screen.getByTestId('icon-azure')).toBeInTheDocument() @@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => { }) it('should render enterprise plan without cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument() expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument() }) it('should not show price tip for community (free) plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument() }) it('should show price tip for premium plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument() }) it('should render features list for each plan', () => { - const { unmount: unmount1 } = render() + const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument() unmount1() - const { unmount: unmount2 } = render() + const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument() unmount2() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument() }) it('should show AWS marketplace icon for premium plan button', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument() }) @@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => { describe('Navigation flow', () => { it('should redirect to GitHub when clicking community plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) @@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to AWS Marketplace when clicking premium plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) @@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to Typeform when clicking enterprise plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) @@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks community button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should NOT redirect expect(assignedHref).toBe('') @@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks premium button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) @@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks enterprise button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) diff --git a/web/__tests__/check-components-diff-coverage.test.ts b/web/__tests__/check-components-diff-coverage.test.ts deleted file mode 100644 index 62e5ff5ed5..0000000000 --- a/web/__tests__/check-components-diff-coverage.test.ts +++ /dev/null @@ -1,221 +0,0 @@ -import { - buildGitDiffRevisionArgs, - getChangedBranchCoverage, - getChangedStatementCoverage, - getIgnoredChangedLinesFromSource, - normalizeToRepoRelative, - parseChangedLineMap, -} from '../scripts/check-components-diff-coverage-lib.mjs' - -describe('check-components-diff-coverage helpers', () => { - it('should build exact and merge-base git diff revision args', () => { - expect(buildGitDiffRevisionArgs('base-sha', 'head-sha', 'exact')).toEqual(['base-sha', 'head-sha']) - expect(buildGitDiffRevisionArgs('base-sha', 'head-sha')).toEqual(['base-sha...head-sha']) - }) - - it('should parse changed line maps from unified diffs', () => { - const diff = [ - 'diff --git a/web/app/components/share/a.ts b/web/app/components/share/a.ts', - '+++ b/web/app/components/share/a.ts', - '@@ -10,0 +11,2 @@', - '+const a = 1', - '+const b = 2', - 'diff --git a/web/app/components/base/b.ts b/web/app/components/base/b.ts', - '+++ b/web/app/components/base/b.ts', - '@@ -20 +21 @@', - '+const c = 3', - 'diff --git a/web/README.md b/web/README.md', - '+++ b/web/README.md', - '@@ -1 +1 @@', - '+ignore me', - ].join('\n') - - const lineMap = parseChangedLineMap(diff, (filePath: string) => filePath.startsWith('web/app/components/')) - - expect([...lineMap.entries()]).toEqual([ - ['web/app/components/share/a.ts', new Set([11, 12])], - ['web/app/components/base/b.ts', new Set([21])], - ]) - }) - - it('should normalize coverage and absolute paths to repo-relative paths', () => { - const repoRoot = '/repo' - const webRoot = '/repo/web' - - expect(normalizeToRepoRelative('web/app/components/share/a.ts', { - appComponentsCoveragePrefix: 'app/components/', - appComponentsPrefix: 'web/app/components/', - repoRoot, - sharedTestPrefix: 'web/__tests__/', - webRoot, - })).toBe('web/app/components/share/a.ts') - - expect(normalizeToRepoRelative('app/components/share/a.ts', { - appComponentsCoveragePrefix: 'app/components/', - appComponentsPrefix: 'web/app/components/', - repoRoot, - sharedTestPrefix: 'web/__tests__/', - webRoot, - })).toBe('web/app/components/share/a.ts') - - expect(normalizeToRepoRelative('/repo/web/app/components/share/a.ts', { - appComponentsCoveragePrefix: 'app/components/', - appComponentsPrefix: 'web/app/components/', - repoRoot, - sharedTestPrefix: 'web/__tests__/', - webRoot, - })).toBe('web/app/components/share/a.ts') - }) - - it('should calculate changed statement coverage from changed lines', () => { - const entry = { - s: { 0: 1, 1: 0 }, - statementMap: { - 0: { start: { line: 10 }, end: { line: 10 } }, - 1: { start: { line: 12 }, end: { line: 13 } }, - }, - } - - const coverage = getChangedStatementCoverage(entry, new Set([10, 12])) - - expect(coverage).toEqual({ - covered: 1, - total: 2, - uncoveredLines: [12], - }) - }) - - it('should report the first changed line inside a multi-line uncovered statement', () => { - const entry = { - s: { 0: 0 }, - statementMap: { - 0: { start: { line: 10 }, end: { line: 14 } }, - }, - } - - const coverage = getChangedStatementCoverage(entry, new Set([13, 14])) - - expect(coverage).toEqual({ - covered: 0, - total: 1, - uncoveredLines: [13], - }) - }) - - it('should fail changed lines when a source file has no coverage entry', () => { - const coverage = getChangedStatementCoverage(undefined, new Set([42, 43])) - - expect(coverage).toEqual({ - covered: 0, - total: 2, - uncoveredLines: [42, 43], - }) - }) - - it('should calculate changed branch coverage using changed branch definitions', () => { - const entry = { - b: { - 0: [1, 0], - }, - branchMap: { - 0: { - line: 20, - loc: { start: { line: 20 }, end: { line: 20 } }, - locations: [ - { start: { line: 20 }, end: { line: 20 } }, - { start: { line: 21 }, end: { line: 21 } }, - ], - type: 'if', - }, - }, - } - - const coverage = getChangedBranchCoverage(entry, new Set([20])) - - expect(coverage).toEqual({ - covered: 1, - total: 2, - uncoveredBranches: [ - { armIndex: 1, line: 21 }, - ], - }) - }) - - it('should report the first changed line inside a multi-line uncovered branch arm', () => { - const entry = { - b: { - 0: [0, 0], - }, - branchMap: { - 0: { - line: 30, - loc: { start: { line: 30 }, end: { line: 35 } }, - locations: [ - { start: { line: 31 }, end: { line: 34 } }, - { start: { line: 35 }, end: { line: 38 } }, - ], - type: 'if', - }, - }, - } - - const coverage = getChangedBranchCoverage(entry, new Set([33])) - - expect(coverage).toEqual({ - covered: 0, - total: 1, - uncoveredBranches: [ - { armIndex: 0, line: 33 }, - ], - }) - }) - - it('should require all branch arms when the branch condition changes', () => { - const entry = { - b: { - 0: [0, 0], - }, - branchMap: { - 0: { - line: 30, - loc: { start: { line: 30 }, end: { line: 35 } }, - locations: [ - { start: { line: 31 }, end: { line: 34 } }, - { start: { line: 35 }, end: { line: 38 } }, - ], - type: 'if', - }, - }, - } - - const coverage = getChangedBranchCoverage(entry, new Set([30])) - - expect(coverage).toEqual({ - covered: 0, - total: 2, - uncoveredBranches: [ - { armIndex: 0, line: 31 }, - { armIndex: 1, line: 35 }, - ], - }) - }) - - it('should ignore changed lines with valid pragma reasons and report invalid pragmas', () => { - const sourceCode = [ - 'const a = 1', - 'const b = 2 // diff-coverage-ignore-line: defensive fallback', - 'const c = 3 // diff-coverage-ignore-line:', - 'const d = 4 // diff-coverage-ignore-line: not changed', - ].join('\n') - - const result = getIgnoredChangedLinesFromSource(sourceCode, new Set([2, 3])) - - expect([...result.effectiveChangedLines]).toEqual([3]) - expect([...result.ignoredLines.entries()]).toEqual([ - [2, 'defensive fallback'], - ]) - expect(result.invalidPragmas).toEqual([ - { line: 3, reason: 'missing ignore reason' }, - ]) - }) -}) diff --git a/web/__tests__/component-coverage-filters.test.ts b/web/__tests__/component-coverage-filters.test.ts deleted file mode 100644 index cacc1e2142..0000000000 --- a/web/__tests__/component-coverage-filters.test.ts +++ /dev/null @@ -1,115 +0,0 @@ -import fs from 'node:fs' -import os from 'node:os' -import path from 'node:path' -import { afterEach, describe, expect, it } from 'vitest' -import { - collectComponentCoverageExcludedFiles, - COMPONENT_COVERAGE_EXCLUDE_LABEL, - getComponentCoverageExclusionReasons, -} from '../scripts/component-coverage-filters.mjs' - -describe('component coverage filters', () => { - describe('getComponentCoverageExclusionReasons', () => { - it('should exclude type-only files by basename', () => { - expect( - getComponentCoverageExclusionReasons( - 'web/app/components/share/text-generation/types.ts', - 'export type ShareMode = "run-once" | "run-batch"', - ), - ).toContain('type-only') - }) - - it('should exclude pure barrel files', () => { - expect( - getComponentCoverageExclusionReasons( - 'web/app/components/base/amplitude/index.ts', - [ - 'export { default } from "./AmplitudeProvider"', - 'export { resetUser, trackEvent } from "./utils"', - ].join('\n'), - ), - ).toContain('pure-barrel') - }) - - it('should exclude generated files from marker comments', () => { - expect( - getComponentCoverageExclusionReasons( - 'web/app/components/base/icons/src/vender/workflow/Answer.tsx', - [ - '// GENERATE BY script', - '// DON NOT EDIT IT MANUALLY', - 'export default function Icon() {', - ' return null', - '}', - ].join('\n'), - ), - ).toContain('generated') - }) - - it('should exclude pure static files with exported constants only', () => { - expect( - getComponentCoverageExclusionReasons( - 'web/app/components/workflow/note-node/constants.ts', - [ - 'import { NoteTheme } from "./types"', - 'export const CUSTOM_NOTE_NODE = "custom-note"', - 'export const THEME_MAP = {', - ' [NoteTheme.blue]: { title: "bg-blue-100" },', - '}', - ].join('\n'), - ), - ).toContain('pure-static') - }) - - it('should keep runtime logic files tracked', () => { - expect( - getComponentCoverageExclusionReasons( - 'web/app/components/workflow/nodes/trigger-schedule/default.ts', - [ - 'const validate = (value: string) => value.trim()', - 'export const nodeDefault = {', - ' value: validate("x"),', - '}', - ].join('\n'), - ), - ).toEqual([]) - }) - }) - - describe('collectComponentCoverageExcludedFiles', () => { - const tempDirs: string[] = [] - - afterEach(() => { - for (const dir of tempDirs) - fs.rmSync(dir, { recursive: true, force: true }) - tempDirs.length = 0 - }) - - it('should collect excluded files for coverage config and keep runtime files out', () => { - const rootDir = fs.mkdtempSync(path.join(os.tmpdir(), 'component-coverage-filters-')) - tempDirs.push(rootDir) - - fs.mkdirSync(path.join(rootDir, 'barrel'), { recursive: true }) - fs.mkdirSync(path.join(rootDir, 'icons'), { recursive: true }) - fs.mkdirSync(path.join(rootDir, 'static'), { recursive: true }) - fs.mkdirSync(path.join(rootDir, 'runtime'), { recursive: true }) - - fs.writeFileSync(path.join(rootDir, 'barrel', 'index.ts'), 'export { default } from "./Button"\n') - fs.writeFileSync(path.join(rootDir, 'icons', 'generated-icon.tsx'), '// @generated\nexport default function Icon() { return null }\n') - fs.writeFileSync(path.join(rootDir, 'static', 'constants.ts'), 'export const COLORS = { primary: "#fff" }\n') - fs.writeFileSync(path.join(rootDir, 'runtime', 'config.ts'), 'export const config = makeConfig()\n') - fs.writeFileSync(path.join(rootDir, 'runtime', 'types.ts'), 'export type Config = { value: string }\n') - - expect(collectComponentCoverageExcludedFiles(rootDir, { pathPrefix: 'app/components' })).toEqual([ - 'app/components/barrel/index.ts', - 'app/components/icons/generated-icon.tsx', - 'app/components/runtime/types.ts', - 'app/components/static/constants.ts', - ]) - }) - }) - - it('should describe the excluded coverage categories', () => { - expect(COMPONENT_COVERAGE_EXCLUDE_LABEL).toBe('type-only files, pure barrel files, generated files, pure static files') - }) -}) diff --git a/web/__tests__/components-coverage-common.test.ts b/web/__tests__/components-coverage-common.test.ts deleted file mode 100644 index ab189ed854..0000000000 --- a/web/__tests__/components-coverage-common.test.ts +++ /dev/null @@ -1,72 +0,0 @@ -import { - getCoverageStats, - isRelevantTestFile, - isTrackedComponentSourceFile, - loadTrackedCoverageEntries, -} from '../scripts/components-coverage-common.mjs' - -describe('components coverage common helpers', () => { - it('should identify tracked component source files and relevant tests', () => { - const excludedComponentCoverageFiles = new Set([ - 'web/app/components/share/types.ts', - ]) - - expect(isTrackedComponentSourceFile('web/app/components/share/index.tsx', excludedComponentCoverageFiles)).toBe(true) - expect(isTrackedComponentSourceFile('web/app/components/share/types.ts', excludedComponentCoverageFiles)).toBe(false) - expect(isTrackedComponentSourceFile('web/app/components/provider/index.tsx', excludedComponentCoverageFiles)).toBe(false) - - expect(isRelevantTestFile('web/__tests__/share/text-generation-run-once-flow.test.tsx')).toBe(true) - expect(isRelevantTestFile('web/app/components/share/__tests__/index.spec.tsx')).toBe(true) - expect(isRelevantTestFile('web/utils/format.spec.ts')).toBe(false) - }) - - it('should load only tracked coverage entries from mixed coverage paths', () => { - const context = { - excludedComponentCoverageFiles: new Set([ - 'web/app/components/share/types.ts', - ]), - repoRoot: '/repo', - webRoot: '/repo/web', - } - const coverage = { - '/repo/web/app/components/provider/index.tsx': { - path: '/repo/web/app/components/provider/index.tsx', - statementMap: { 0: { start: { line: 1 }, end: { line: 1 } } }, - s: { 0: 1 }, - }, - 'app/components/share/index.tsx': { - path: 'app/components/share/index.tsx', - statementMap: { 0: { start: { line: 2 }, end: { line: 2 } } }, - s: { 0: 1 }, - }, - 'app/components/share/types.ts': { - path: 'app/components/share/types.ts', - statementMap: { 0: { start: { line: 3 }, end: { line: 3 } } }, - s: { 0: 1 }, - }, - } - - expect([...loadTrackedCoverageEntries(coverage, context).keys()]).toEqual([ - 'web/app/components/share/index.tsx', - ]) - }) - - it('should calculate coverage stats using statement-derived line hits', () => { - const entry = { - b: { 0: [1, 0] }, - f: { 0: 1, 1: 0 }, - s: { 0: 1, 1: 0 }, - statementMap: { - 0: { start: { line: 10 }, end: { line: 10 } }, - 1: { start: { line: 12 }, end: { line: 13 } }, - }, - } - - expect(getCoverageStats(entry)).toEqual({ - branches: { covered: 1, total: 2 }, - functions: { covered: 1, total: 2 }, - lines: { covered: 1, total: 2 }, - statements: { covered: 1, total: 2 }, - }) - }) -}) diff --git a/web/__tests__/datasets/dataset-settings-flow.test.tsx b/web/__tests__/datasets/dataset-settings-flow.test.tsx index 607cd8c2d5..b4a5e78326 100644 --- a/web/__tests__/datasets/dataset-settings-flow.test.tsx +++ b/web/__tests__/datasets/dataset-settings-flow.test.tsx @@ -19,6 +19,10 @@ import { RETRIEVE_METHOD } from '@/types/app' // --- Mocks --- +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() const mockUpdateDatasetSetting = vi.fn().mockResolvedValue({}) @@ -55,8 +59,11 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), + }, })) // --- Dataset factory --- @@ -311,7 +318,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { describe('Form Submission Validation → All Fields Together', () => { it('should reject empty name on save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -322,10 +329,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() }) diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 8aedd4fc63..f9d80520ed 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -13,7 +13,7 @@ import { DataSourceType } from '@/models/datasets' import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => new URLSearchParams(''), useRouter: () => ({ push: mockPush }), usePathname: () => '/datasets/ds-1/documents', diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index 6b348cd15b..5cb115830e 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -7,12 +7,12 @@ import type { Mock } from 'vitest' */ import { fireEvent, render, screen } from '@testing-library/react' -import { useRouter } from 'next/navigation' +import { useRouter } from '@/next/navigation' import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: mockPush, })), diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 9231ac6199..cacd6331f8 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -8,7 +8,7 @@ const replaceMock = vi.fn() const backMock = vi.fn() const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/test-app'), useRouter: vi.fn(() => ({ replace: replaceMock, diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 901218e76b..04597ccfeb 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -4,7 +4,7 @@ import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/sample-app'), useSearchParams: vi.fn(() => { const params = new URLSearchParams() diff --git a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx index e2c18bcc4f..64dd5321ac 100644 --- a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx +++ b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx @@ -7,19 +7,23 @@ */ import type { InstalledApp } from '@/models/explore' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import Toast from '@/app/components/base/toast' import SideBar from '@/app/components/explore/sidebar' import { MediaType } from '@/hooks/use-breakpoints' import { AppModeEnum } from '@/types/app' +const { mockToastSuccess } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), +})) + let mockMediaType: string = MediaType.pc const mockSegments = ['apps'] const mockPush = vi.fn() const mockUninstall = vi.fn() const mockUpdatePinStatus = vi.fn() let mockInstalledApps: InstalledApp[] = [] +let mockIsUninstallPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegments: () => mockSegments, useRouter: () => ({ push: mockPush, @@ -42,12 +46,24 @@ vi.mock('@/service/use-explore', () => ({ }), useUninstallApp: () => ({ mutateAsync: mockUninstall, + isPending: mockIsUninstallPending, }), useUpdateAppPinStatus: () => ({ mutateAsync: mockUpdatePinStatus, }), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + }, + } +}) + const createInstalledApp = (overrides: Partial = {}): InstalledApp => ({ id: overrides.id ?? 'app-1', uninstallable: overrides.uninstallable ?? false, @@ -74,7 +90,7 @@ describe('Sidebar Lifecycle Flow', () => { vi.clearAllMocks() mockMediaType = MediaType.pc mockInstalledApps = [] - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + mockIsUninstallPending = false }) describe('Pin / Unpin / Delete Flow', () => { @@ -91,9 +107,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true }) - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) // Step 2: Simulate refetch returning pinned state, then unpin @@ -110,9 +124,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false }) - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) }) @@ -136,10 +148,7 @@ describe('Sidebar Lifecycle Flow', () => { // Step 4: Uninstall API called and success toast shown await waitFor(() => { expect(mockUninstall).toHaveBeenCalledWith('app-1') - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - message: 'common.api.remove', - })) + expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove') }) }) diff --git a/web/app/components/browser-initializer.spec.ts b/web/__tests__/instrumentation-client.spec.ts similarity index 100% rename from web/app/components/browser-initializer.spec.ts rename to web/__tests__/instrumentation-client.spec.ts diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 7ceca4535b..8fa2246198 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -12,8 +12,16 @@ vi.mock('@/config', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockUploadGitHub = vi.fn() @@ -22,33 +30,6 @@ vi.mock('@/service/plugins', () => ({ checkTaskStatus: vi.fn(), })) -vi.mock('@/utils/semver', () => ({ - compareVersion: (a: string, b: string) => { - const parse = (v: string) => v.replace(/^v/, '').split('.').map(Number) - const [aMajor, aMinor = 0, aPatch = 0] = parse(a) - const [bMajor, bMinor = 0, bPatch = 0] = parse(b) - if (aMajor !== bMajor) - return aMajor > bMajor ? 1 : -1 - if (aMinor !== bMinor) - return aMinor > bMinor ? 1 : -1 - if (aPatch !== bPatch) - return aPatch > bPatch ? 1 : -1 - return 0 - }, - getLatestVersion: (versions: string[]) => { - return versions.sort((a, b) => { - const parse = (v: string) => v.replace(/^v/, '').split('.').map(Number) - const [aMaj, aMin = 0, aPat = 0] = parse(a) - const [bMaj, bMin = 0, bPat = 0] = parse(b) - if (aMaj !== bMaj) - return bMaj - aMaj - if (aMin !== bMin) - return bMin - aMin - return bPat - aPat - })[0] - }, -})) - const { useGitHubReleases, useGitHubUpload } = await import( '@/app/components/plugins/install-plugin/hooks', ) diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx b/web/__tests__/share/text-generation-index-flow.test.tsx index 3292474bec..2fec054a47 100644 --- a/web/__tests__/share/text-generation-index-flow.test.tsx +++ b/web/__tests__/share/text-generation-index-flow.test.tsx @@ -5,7 +5,7 @@ import TextGeneration from '@/app/components/share/text-generation' const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => useSearchParamsMock(), })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index fd0bf2c8bd..0c87fd1a4d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -13,8 +13,6 @@ import { RiTerminalWindowLine, } from '@remixicon/react' import { useUnmount } from 'ahooks' -import dynamic from 'next/dynamic' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -26,6 +24,8 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import dynamic from '@/next/dynamic' +import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 5e7d98d191..4201d11490 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -7,7 +7,6 @@ import { RiEqualizer2Line, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -17,6 +16,7 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import Indicator from '@/app/components/header/indicator' import { useAppContext } from '@/context/app-context' +import { usePathname } from '@/next/navigation' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' import { cn } from '@/utils/classnames' import ConfigButton from './config-button' diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 4f3f724e62..730b76ee19 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -9,7 +9,6 @@ import { RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -23,6 +22,7 @@ import DatasetDetailContext from '@/context/dataset-detail' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { usePathname } from '@/next/navigation' import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/datasets/layout.spec.tsx b/web/app/(commonLayout)/datasets/layout.spec.tsx index 5873f344d0..9c01cffba8 100644 --- a/web/app/(commonLayout)/datasets/layout.spec.tsx +++ b/web/app/(commonLayout)/datasets/layout.spec.tsx @@ -6,7 +6,7 @@ import DatasetsLayout from './layout' const mockReplace = vi.fn() const mockUseAppContext = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/(commonLayout)/datasets/layout.tsx b/web/app/(commonLayout)/datasets/layout.tsx index b543c42570..a465f8222b 100644 --- a/web/app/(commonLayout)/datasets/layout.tsx +++ b/web/app/(commonLayout)/datasets/layout.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' import { ExternalApiPanelProvider } from '@/context/external-api-panel-context' import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context' +import { useRouter } from '@/next/navigation' export default function DatasetsLayout({ children }: { children: React.ReactNode }) { const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext() diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index fce6fe1d5d..44ba5ee8ad 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,15 +1,15 @@ 'use client' -import { - useRouter, - useSearchParams, -} from 'next/navigation' import { useEffect, useMemo, } from 'react' import EducationApplyPage from '@/app/education-apply/education-apply-page' import { useProviderContext } from '@/context/provider-context' +import { + useRouter, + useSearchParams, +} from '@/next/navigation' export default function EducationApply() { const router = useRouter() diff --git a/web/app/(commonLayout)/role-route-guard.spec.tsx b/web/app/(commonLayout)/role-route-guard.spec.tsx index 87bf9be8af..ca1550f0b8 100644 --- a/web/app/(commonLayout)/role-route-guard.spec.tsx +++ b/web/app/(commonLayout)/role-route-guard.spec.tsx @@ -6,7 +6,7 @@ const mockReplace = vi.fn() const mockUseAppContext = vi.fn() let mockPathname = '/apps' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, useRouter: () => ({ replace: mockReplace, diff --git a/web/app/(commonLayout)/role-route-guard.tsx b/web/app/(commonLayout)/role-route-guard.tsx index 1c42be9d15..483dfef095 100644 --- a/web/app/(commonLayout)/role-route-guard.tsx +++ b/web/app/(commonLayout)/role-route-guard.tsx @@ -1,10 +1,10 @@ 'use client' import type { ReactNode } from 'react' -import { usePathname, useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' +import { usePathname, useRouter } from '@/next/navigation' const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx index d027ef8b7d..035da6be8a 100644 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ b/web/app/(humanInputLayout)/form/[token]/form.tsx @@ -9,7 +9,6 @@ import { RiInformation2Fill, } from '@remixicon/react' import { produce } from 'immer' -import { useParams } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -21,6 +20,7 @@ import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-inp import Loading from '@/app/components/base/loading' import DifyLogo from '@/app/components/base/logo/dify-logo' import useDocumentTitle from '@/hooks/use-document-title' +import { useParams } from '@/next/navigation' import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..9f956a8501 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -1,12 +1,12 @@ 'use client' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' import { webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..402005752d 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport, webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index fbf45259e5..1d1c6518fe 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -1,14 +1,14 @@ 'use client' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Countdown from '@/app/components/signin/countdown' - import { useLocale } from '@/context/i18n' + +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -24,17 +24,11 @@ export default function CheckCode() { const verify = async () => { try { if (!code.trim()) { - Toast.notify({ - type: 'error', - message: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - Toast.notify({ - type: 'error', - message: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 9b9a853cdd..0cdfb4ec11 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,18 +1,18 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' - import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' + +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -27,15 +27,12 @@ export default function CheckCode() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) @@ -48,16 +45,10 @@ export default function CheckCode() { router.push(`/webapp-reset-password/check-code?${params.toString()}`) } else if (res.code === 'account_not_found') { - Toast.notify({ - type: 'error', - message: t('error.registrationNotAllowed', { ns: 'login' }), - }) + toast.error(t('error.registrationNotAllowed', { ns: 'login' })) } else { - Toast.notify({ - type: 'error', - message: res.data, - }) + toast.error(res.data) } } catch (error) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 9f59e8f9eb..bc8f651d17 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -1,13 +1,13 @@ 'use client' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { validPassword } from '@/config' +import { useRouter, useSearchParams } from '@/next/navigation' import { changeWebAppPasswordWithToken } from '@/service/common' import { cn } from '@/utils/classnames' @@ -24,10 +24,7 @@ const ChangePasswordForm = () => { const [showConfirmPassword, setShowConfirmPassword] = useState(false) const showErrorMessage = useCallback((message: string) => { - Toast.notify({ - type: 'error', - message, - }) + toast.error(message) }, []) const getSignInUrl = () => { diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index afea9d668b..f209ad9e5c 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -1,15 +1,15 @@ 'use client' import type { FormEvent } from 'react' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Countdown from '@/app/components/signin/countdown' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' @@ -43,24 +43,15 @@ export default function CheckCode() { try { const appCode = getAppCodeFromRedirectUrl() if (!code.trim()) { - Toast.notify({ - type: 'error', - message: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - Toast.notify({ - type: 'error', - message: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx index 0776df036d..9b4a369908 100644 --- a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' @@ -17,10 +17,7 @@ const ExternalMemberSSOAuth = () => { const redirectUrl = searchParams.get('redirect_url') const showErrorToast = (message: string) => { - Toast.notify({ - type: 'error', - message, - }) + toast.error(message) } const getAppCodeFromRedirectUrl = useCallback(() => { diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 5aa9d9f141..fbd6b216df 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,13 +1,13 @@ import { noop } from 'es-toolkit/function' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { @@ -22,15 +22,12 @@ export default function MailAndCodeAuth() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index e49559401d..1e9355e7ba 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,15 +1,15 @@ 'use client' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' @@ -46,26 +46,20 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } if (!password?.trim()) { - Toast.notify({ type: 'error', message: t('error.passwordEmpty', { ns: 'login' }) }) + toast.error(t('error.passwordEmpty', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } try { @@ -94,15 +88,12 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut router.replace(decodeURIComponent(redirectUrl)) } else { - Toast.notify({ - type: 'error', - message: res.data, - }) + toast.error(res.data) } } catch (e: any) { if (e.code === 'authentication_failed') - Toast.notify({ type: 'error', message: e.message }) + toast.error(e.message) } finally { setIsLoading(false) diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index d8f3854868..3178c638cc 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' @@ -37,10 +37,7 @@ const SSOAuth: FC = ({ const handleSSOLogin = () => { const appCode = getAppCodeFromRedirectUrl() if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: 'invalid redirect URL or app code', - }) + toast.error(t('error.invalidRedirectUrlOrAppCode', { ns: 'login' })) return } setIsLoading(true) @@ -66,10 +63,7 @@ const SSOAuth: FC = ({ }) } else { - Toast.notify({ - type: 'error', - message: 'invalid SSO protocol', - }) + toast.error(t('error.invalidSSOProtocol', { ns: 'login' })) setIsLoading(false) } } diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index b15145346f..7ee08d66ae 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -1,12 +1,12 @@ 'use client' import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' +import Link from '@/next/link' import { LicenseStatus } from '@/types/feature' import { cn } from '@/utils/classnames' import MailAndCodeAuth from './components/mail-and-code-auth' diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index b3ad1d48a6..a5c2528cc7 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -8,6 +7,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import { useGlobalPublicStore } from '@/context/global-public-context' import { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogout } from '@/service/webapp-auth' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import NormalForm from './normalForm' diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index c146174ea9..f0dfd4f12f 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -1,7 +1,6 @@ import type { ResponseError } from '@/service/fetch' import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' @@ -10,6 +9,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast/context' +import { useRouter } from '@/next/navigation' import { checkEmailExisted, resetEmail, diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 07b685b8c5..0b3541ae9c 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -3,7 +3,6 @@ import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/r import { RiGraduationCapFill, } from '@remixicon/react' -import { useRouter } from 'next/navigation' import { Fragment } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' @@ -11,6 +10,7 @@ import { Avatar } from '@/app/components/base/avatar' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { useLogout, useUserProfile } from '@/service/use-common' export type IAppSelector = { diff --git a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx index 65a58c936e..e0f00189b2 100644 --- a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx @@ -1,10 +1,10 @@ 'use client' -import Link from 'next/link' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import { useAppContext } from '@/context/app-context' +import Link from '@/next/link' import { useSendDeleteAccountEmail } from '../state' type DeleteAccountProps = { diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 67fea3c141..ae73d778f8 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -1,5 +1,4 @@ 'use client' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -7,6 +6,7 @@ import CustomDialog from '@/app/components/base/dialog' import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' +import { useRouter } from '@/next/navigation' import { useLogout } from '@/service/use-common' import { useDeleteAccountFeedback } from '../state' diff --git a/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx index d7590c27f9..5d76f13f34 100644 --- a/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx @@ -1,10 +1,10 @@ 'use client' -import Link from 'next/link' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Countdown from '@/app/components/signin/countdown' +import Link from '@/next/link' import { useAccountDeleteStore, useConfirmDeleteAccount, useSendDeleteAccountEmail } from '../state' const CODE_EXP = /[A-Z\d]{6}/gi diff --git a/web/app/account/(commonLayout)/header.tsx b/web/app/account/(commonLayout)/header.tsx index bb58be87a8..5ef84a8f1e 100644 --- a/web/app/account/(commonLayout)/header.tsx +++ b/web/app/account/(commonLayout)/header.tsx @@ -1,11 +1,11 @@ 'use client' import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' -import { useRouter } from 'next/navigation' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import DifyLogo from '@/app/components/base/logo/dify-logo' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter } from '@/next/navigation' import Avatar from './avatar' const Header = () => { diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 835a1e702e..670f6ec593 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -7,16 +7,16 @@ import { RiMailLine, RiTranslate2, } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' +import { useRouter, useSearchParams } from '@/next/navigation' import { useIsLogin, useUserProfile } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' @@ -91,10 +91,7 @@ export default function OAuthAuthorize() { globalThis.location.href = url.toString() } catch (err: any) { - Toast.notify({ - type: 'error', - message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, - }) + toast.error(`${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`) } } @@ -102,11 +99,10 @@ export default function OAuthAuthorize() { const invalidParams = !client_id || !redirect_uri if ((invalidParams || isError) && !hasNotifiedRef.current) { hasNotifiedRef.current = true - Toast.notify({ - type: 'error', - message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), - duration: 0, - }) + toast.error( + invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), + { timeout: 0 }, + ) } }, [client_id, redirect_uri, isError]) diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 421b816652..418d3b8bb1 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import { useEffect } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' - import useDocumentTitle from '@/hooks/use-document-title' + +import { useRouter, useSearchParams } from '@/next/navigation' import { useInvitationCheck } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index bf7aa39580..e08ece6666 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -2,13 +2,13 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx index 89db80e0f1..b2e1e92bbb 100644 --- a/web/app/components/app-sidebar/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -19,7 +19,7 @@ vi.mock('zustand/react/shallow', () => ({ useShallow: (fn: unknown) => fn, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, })) diff --git a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx index fb19833dd2..a3868a8330 100644 --- a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx @@ -7,7 +7,7 @@ import { render } from '@testing-library/react' import * as React from 'react' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx index f8612e8057..2f98089e40 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { AppModeEnum } from '@/types/app' import AppInfoModals from '../app-info-modals' -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComp = React.lazy(loader) return function DynamicWrapper(props: Record) { diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index 6104e2b641..deea28ce3e 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -23,7 +23,7 @@ let mockAppDetail: Record | undefined = { icon_background: '#FFEAD5', } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/app-info/app-info-modals.tsx b/web/app/components/app-sidebar/app-info/app-info-modals.tsx index 4ca7f6adbc..6b76be87bb 100644 --- a/web/app/components/app-sidebar/app-info/app-info-modals.tsx +++ b/web/app/components/app-sidebar/app-info/app-info-modals.tsx @@ -3,9 +3,10 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-moda import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' import type { App, AppSSO } from '@/types/app' -import dynamic from 'next/dynamic' import * as React from 'react' +import { useState } from 'react' import { useTranslation } from 'react-i18next' +import dynamic from '@/next/dynamic' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false }) @@ -42,6 +43,7 @@ const AppInfoModals = ({ onConfirmDelete, }: AppInfoModalsProps) => { const { t } = useTranslation() + const [confirmDeleteInput, setConfirmDeleteInput] = useState('') return ( <> @@ -88,8 +90,16 @@ const AppInfoModals = ({ title={t('deleteAppConfirmTitle', { ns: 'app' })} content={t('deleteAppConfirmContent', { ns: 'app' })} isShow + confirmInputLabel={t('deleteAppConfirmInputLabel', { ns: 'app', appName: appDetail.name })} + confirmInputPlaceholder={t('deleteAppConfirmInputPlaceholder', { ns: 'app' })} + confirmInputValue={confirmDeleteInput} + onConfirmInputChange={setConfirmDeleteInput} + confirmInputMatchValue={appDetail.name} onConfirm={onConfirmDelete} - onCancel={closeModal} + onCancel={() => { + setConfirmDeleteInput('') + closeModal() + }} /> )} {activeModal === 'importDSL' && ( diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 800f21de44..55ec13e506 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -1,7 +1,6 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -9,6 +8,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' import { useInvalidateAppList } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx index 512f9490c2..1df6fa79b7 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx @@ -80,7 +80,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx index be27e247d7..a1e275d731 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx @@ -90,7 +90,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 96127c4210..528bac831f 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -1,11 +1,11 @@ import type { DataSet } from '@/models/datasets' import { RiMoreFill } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { useRouter } from '@/next/navigation' import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index e24b005d01..13fde97f89 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,12 +1,12 @@ import type { NavIcon } from './nav-link' import { useHover, useKeyPress } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { usePathname } from '@/next/navigation' import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { getKeyboardKeyCodeBySystem } from '../workflow/utils' diff --git a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx index 04ca7bd0e4..fe46290002 100644 --- a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx @@ -4,12 +4,12 @@ import * as React from 'react' import NavLink from '..' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock Next.js Link component -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: function MockLink({ children, href, className, title }: { children: React.ReactNode, href: string, className?: string, title?: string }) { return ( diff --git a/web/app/components/app-sidebar/nav-link/index.tsx b/web/app/components/app-sidebar/nav-link/index.tsx index d69ed8590e..cf986a7407 100644 --- a/web/app/components/app-sidebar/nav-link/index.tsx +++ b/web/app/components/app-sidebar/nav-link/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { RemixiconComponentType } from '@remixicon/react' -import Link from 'next/link' -import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' +import Link from '@/next/link' +import { useSelectedLayoutSegment } from '@/next/navigation' import { cn } from '@/utils/classnames' export type NavIcon = React.ComponentType< diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index f5ebaac3ca..8ad284bcfb 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -298,7 +298,6 @@ const GetAutomaticRes: FC = ({
= (
({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx index aa8dae813f..6704fa0afd 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import VarPicker from './var-picker' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index d2e4913e54..6dd03d217e 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -370,7 +370,6 @@ const ConfigContent: FC = ({ const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction -let toastNotifySpy: MockInstance +let toastErrorSpy: MockInstance const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { return { @@ -140,7 +140,7 @@ describe('dataset-config/params-config', () => { beforeEach(() => { vi.clearAllMocks() vi.useRealTimers() - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + toastErrorSpy = vi.spyOn(toast, 'error').mockImplementation(() => '') mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ modelList: [], defaultModel: undefined, @@ -154,7 +154,7 @@ describe('dataset-config/params-config', () => { }) afterEach(() => { - toastNotifySpy.mockRestore() + toastErrorSpy.mockRestore() }) // Rendering tests (REQUIRED) @@ -254,10 +254,7 @@ describe('dataset-config/params-config', () => { await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'appDebug.datasetConfig.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('appDebug.datasetConfig.rerankModelRequired') expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 692ae12022..89410203df 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { @@ -66,10 +66,7 @@ const ParamsConfig = ({ } } if (errMsg) { - Toast.notify({ - type: 'error', - message: errMsg, - }) + toast.error(errMsg) } return !errMsg } diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 91e5353cc4..8c2fb77c20 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import type { DataSet } from '@/models/datasets' import { useInfiniteScroll } from 'ahooks' -import Link from 'next/link' import * as React from 'react' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -14,6 +13,7 @@ import Modal from '@/app/components/base/modal' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' import { useKnowledge } from '@/hooks/use-knowledge' +import Link from '@/next/link' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx index 48141d0045..a75516a43f 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -155,7 +155,7 @@ vi.mock('@/service/debug', () => ({ stopChatMessageResponding: mockStopChatMessageResponding, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', useParams: () => ({}), diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 0e6ffb1e84..aa1bbe0a16 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -23,7 +23,6 @@ import { useBoolean, useGetState } from 'ahooks' import { clone } from 'es-toolkit/object' import { isEqual } from 'es-toolkit/predicate' import { produce } from 'immer' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -72,6 +71,7 @@ import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { PromptMode } from '@/models/debug' +import { usePathname } from '@/next/navigation' import { fetchAppDetailDirect, updateAppModelConfig } from '@/service/apps' import { fetchDatasets } from '@/service/datasets' import { fetchCollectionList } from '@/service/tools' diff --git a/web/app/components/app/create-app-dialog/app-list/index.spec.tsx b/web/app/components/app/create-app-dialog/app-list/index.spec.tsx index e0f459ee75..a9b65a4ae9 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.spec.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.spec.tsx @@ -39,8 +39,8 @@ vi.mock('../app-card', () => ({ vi.mock('@/app/components/explore/create-app-modal', () => ({ default: () =>
, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { add: vi.fn() }, })) vi.mock('@/app/components/base/amplitude', () => ({ trackEvent: vi.fn(), @@ -62,7 +62,7 @@ vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ vi.mock('@/utils/app-redirection', () => ({ getRedirection: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), })) diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index b967ba7d55..1aa40d2014 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -4,7 +4,6 @@ import type { CreateAppModalProps } from '@/app/components/explore/create-app-mo import type { App } from '@/models/explore' import { RiRobot2Line } from '@remixicon/react' import { useDebounceFn } from 'ahooks' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -13,12 +12,13 @@ import { trackEvent } from '@/app/components/base/amplitude' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import CreateAppModal from '@/app/components/explore/create-app-modal' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { DSLImportMode } from '@/models/app' +import { useRouter } from '@/next/navigation' import { importDSL } from '@/service/apps' import { fetchAppDetail } from '@/service/explore' import { useExploreAppList } from '@/service/use-explore' @@ -137,10 +137,7 @@ const Apps = ({ }) setIsShowCreateModal(false) - Toast.notify({ - type: 'success', - message: t('newApp.appCreated', { ns: 'app' }), - }) + toast.success(t('newApp.appCreated', { ns: 'app' })) if (onSuccess) onSuccess() if (app.app_id) @@ -149,7 +146,7 @@ const Apps = ({ getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push) } catch { - Toast.notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index a9adb17582..c253fcd457 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -1,13 +1,13 @@ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' import { trackEvent } from '@/app/components/base/amplitude' - import { ToastContext } from '@/app/components/base/toast/context' + import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' @@ -23,7 +23,7 @@ vi.mock('ahooks', () => ({ useKeyPress: vi.fn(), useHover: () => false, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(), })) vi.mock('@/app/components/base/amplitude', () => ({ diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 1c22913bb1..556773c341 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -4,7 +4,6 @@ import type { AppIconSelection } from '../../base/app-icon-picker' import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' -import { useRouter } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -22,6 +21,7 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import useTheme from '@/hooks/use-theme' +import { useRouter } from '@/next/navigation' import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index a0c8360c29..eaaee50973 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -4,7 +4,6 @@ import type { MouseEventHandler } from 'react' import { RiCloseLine } from '@remixicon/react' import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -22,6 +21,7 @@ import { DSLImportMode, DSLImportStatus, } from '@/models/app' +import { useRouter } from '@/next/navigation' import { importDSL, importDSLConfirm, diff --git a/web/app/components/app/log-annotation/index.spec.tsx b/web/app/components/app/log-annotation/index.spec.tsx index c7c654e870..de33ae6f66 100644 --- a/web/app/components/app/log-annotation/index.spec.tsx +++ b/web/app/components/app/log-annotation/index.spec.tsx @@ -7,7 +7,7 @@ import { AppModeEnum } from '@/types/app' import LogAnnotation from './index' const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index ca6182603d..c5c21289df 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' @@ -11,6 +10,7 @@ import WorkflowLog from '@/app/components/app/workflow-log' import { PageType } from '@/app/components/base/features/new-feature-panel/annotation-reply/type' import Loading from '@/app/components/base/loading' import TabSlider from '@/app/components/base/tab-slider-plain' +import { useRouter } from '@/next/navigation' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/log/empty-element.tsx b/web/app/components/app/log/empty-element.tsx index e42a1df7d5..95b0e7f03f 100644 --- a/web/app/components/app/log/empty-element.tsx +++ b/web/app/components/app/log/empty-element.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC, SVGProps } from 'react' import type { App } from '@/types/app' -import Link from 'next/link' import * as React from 'react' import { Trans, useTranslation } from 'react-i18next' +import Link from '@/next/link' import { AppModeEnum } from '@/types/app' import { getRedirectionPath } from '@/utils/app-redirection' import { basePath } from '@/utils/var' diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index e96c9ce0c9..59f454f754 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -4,13 +4,13 @@ import type { App } from '@/types/app' import { useDebounce } from 'ahooks' import dayjs from 'dayjs' import { omit } from 'es-toolkit/object' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import Pagination from '@/app/components/base/pagination' import { APP_PAGE_LIMIT } from '@/config' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useChatConversations, useCompletionConversations } from '@/service/use-log' import { AppModeEnum } from '@/types/app' import EmptyElement from './empty-element' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 146af44a10..453c7c9d4c 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -14,7 +14,6 @@ import timezone from 'dayjs/plugin/timezone' import utc from 'dayjs/plugin/utc' import { get } from 'es-toolkit/compat' import { noop } from 'es-toolkit/function' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -38,6 +37,7 @@ import { WorkflowContextProvider } from '@/app/components/workflow/context' import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' import { AppSourceType } from '@/service/share' import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log' diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index 1b02e54d5f..42cf4d8618 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -14,7 +14,6 @@ import { RiVerifiedBadgeLine, RiWindowLine, } from '@remixicon/react' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -34,6 +33,7 @@ import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useDocLink } from '@/context/i18n' import { AccessMode } from '@/models/access-control' +import { usePathname, useRouter } from '@/next/navigation' import { useAppWhiteListSubjects } from '@/service/access-control' import { fetchAppDetailDirect } from '@/service/apps' import { useAppWorkflow } from '@/service/use-workflow' diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index f7c9e309ab..13dacde424 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -4,7 +4,6 @@ import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import type { AppDetailResponse } from '@/models/app' import type { AppIconType, AppSSO, Language } from '@/types/app' import { RiArrowRightSLine, RiCloseLine } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useEffect, useRef, useState } from 'react' import { Trans, useTranslation } from 'react-i18next' @@ -26,6 +25,7 @@ import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/con import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { languages } from '@/i18n-config/language' +import Link from '@/next/link' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/overview/trigger-card.tsx b/web/app/components/app/overview/trigger-card.tsx index 1f0f0dca56..09e3a08393 100644 --- a/web/app/components/app/overview/trigger-card.tsx +++ b/web/app/components/app/overview/trigger-card.tsx @@ -3,7 +3,6 @@ import type { AppDetailResponse } from '@/models/app' import type { AppTrigger } from '@/service/use-tools' import type { AppSSO } from '@/types/app' import type { I18nKeysByPrefix } from '@/types/i18n' -import Link from 'next/link' import * as React from 'react' import { useTranslation } from 'react-i18next' import { TriggerAll } from '@/app/components/base/icons/src/vender/workflow' @@ -13,6 +12,7 @@ import { useTriggerStatusStore } from '@/app/components/workflow/store/trigger-s import { BlockEnum } from '@/app/components/workflow/types' import { useAppContext } from '@/context/app-context' import { useDocLink } from '@/context/i18n' +import Link from '@/next/link' import { useAppTriggers, diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx index c905d79b31..53007b986b 100644 --- a/web/app/components/app/switch-app-modal/index.spec.tsx +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -11,7 +11,7 @@ import SwitchAppModal from './index' const mockPush = vi.fn() const mockReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, replace: mockReplace, diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 8caa07c187..7c3269d52c 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -3,7 +3,6 @@ import type { App } from '@/types/app' import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -20,6 +19,7 @@ import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { deleteApp, switchApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 22358805a7..d22375a292 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -16,7 +16,6 @@ import { } from '@remixicon/react' import { useBoolean } from 'ahooks' import copy from 'copy-to-clipboard' -import { useParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -30,6 +29,7 @@ import Loading from '@/app/components/base/loading' import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' import Toast from '@/app/components/base/toast' +import { useParams } from '@/next/navigation' import { fetchTextGenerationMessage } from '@/service/debug' import { AppSourceType, fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share' import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' diff --git a/web/app/components/app/text-generate/saved-items/index.spec.tsx b/web/app/components/app/text-generate/saved-items/index.spec.tsx index f04a37bded..b45a1cca6c 100644 --- a/web/app/components/app/text-generate/saved-items/index.spec.tsx +++ b/web/app/components/app/text-generate/saved-items/index.spec.tsx @@ -10,7 +10,7 @@ import SavedItems from './index' vi.mock('copy-to-clipboard', () => ({ default: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({}), usePathname: () => '/', })) diff --git a/web/app/components/app/type-selector/index.spec.tsx b/web/app/components/app/type-selector/index.spec.tsx index e24d963305..711678f0a8 100644 --- a/web/app/components/app/type-selector/index.spec.tsx +++ b/web/app/components/app/type-selector/index.spec.tsx @@ -1,4 +1,4 @@ -import { fireEvent, render, screen, within } from '@testing-library/react' +import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import * as React from 'react' import { AppModeEnum } from '@/types/app' import AppTypeSelector, { AppTypeIcon, AppTypeLabel } from './index' @@ -14,7 +14,7 @@ describe('AppTypeSelector', () => { render() expect(screen.getByText('app.typeSelector.all')).toBeInTheDocument() - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() }) }) @@ -39,24 +39,27 @@ describe('AppTypeSelector', () => { // Covers opening/closing the dropdown and selection updates. describe('User interactions', () => { - it('should toggle option list when clicking the trigger', () => { + it('should close option list when clicking outside', () => { render() - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByRole('list')).not.toBeInTheDocument() - fireEvent.click(screen.getByText('app.typeSelector.all')) - expect(screen.getByRole('tooltip')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.all' })) + expect(screen.getByRole('list')).toBeInTheDocument() - fireEvent.click(screen.getByText('app.typeSelector.all')) - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + fireEvent.pointerDown(document.body) + fireEvent.click(document.body) + return waitFor(() => { + expect(screen.queryByRole('list')).not.toBeInTheDocument() + }) }) it('should call onChange with added type when selecting an unselected item', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.all')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.all' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.workflow' })) expect(onChange).toHaveBeenCalledWith([AppModeEnum.WORKFLOW]) }) @@ -65,8 +68,8 @@ describe('AppTypeSelector', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.workflow')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.workflow' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.workflow' })) expect(onChange).toHaveBeenCalledWith([]) }) @@ -75,8 +78,8 @@ describe('AppTypeSelector', () => { const onChange = vi.fn() render() - fireEvent.click(screen.getByText('app.typeSelector.chatbot')) - fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.agent')) + fireEvent.click(screen.getByRole('button', { name: 'app.typeSelector.chatbot' })) + fireEvent.click(within(screen.getByRole('list')).getByRole('button', { name: 'app.typeSelector.agent' })) expect(onChange).toHaveBeenCalledWith([AppModeEnum.CHAT, AppModeEnum.AGENT_CHAT]) }) @@ -88,7 +91,7 @@ describe('AppTypeSelector', () => { fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) expect(onChange).toHaveBeenCalledWith([]) - expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() }) }) }) diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx index e97da4b7f3..a1475f9eff 100644 --- a/web/app/components/app/type-selector/index.tsx +++ b/web/app/components/app/type-selector/index.tsx @@ -4,13 +4,12 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' -import Checkbox from '../../base/checkbox' export type AppSelectorProps = { value: Array @@ -22,43 +21,43 @@ const allTypes: AppModeEnum[] = [AppModeEnum.WORKFLOW, AppModeEnum.ADVANCED_CHAT const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { const [open, setOpen] = useState(false) const { t } = useTranslation() + const triggerLabel = value.length === 0 + ? t('typeSelector.all', { ns: 'app' }) + : value.map(type => getAppTypeLabel(type, t)).join(', ') return ( -
- setOpen(v => !v)} - className="block" - > -
0 && 'pr-7', )} + > + + + {value.length > 0 && ( + - )} -
-
- -
    + + + )} + +
      {allTypes.map(mode => ( { /> ))}
    - +
-
+ ) } @@ -173,33 +172,54 @@ type AppTypeSelectorItemProps = { } function AppTypeSelectorItem({ checked, type, onClick }: AppTypeSelectorItemProps) { return ( -
  • - - -
    - -
    +
  • +
  • ) } +function getAppTypeLabel(type: AppModeEnum, t: ReturnType['t']) { + if (type === AppModeEnum.CHAT) + return t('typeSelector.chatbot', { ns: 'app' }) + if (type === AppModeEnum.AGENT_CHAT) + return t('typeSelector.agent', { ns: 'app' }) + if (type === AppModeEnum.COMPLETION) + return t('typeSelector.completion', { ns: 'app' }) + if (type === AppModeEnum.ADVANCED_CHAT) + return t('typeSelector.advanced', { ns: 'app' }) + if (type === AppModeEnum.WORKFLOW) + return t('typeSelector.workflow', { ns: 'app' }) + + return '' +} + type AppTypeLabelProps = { type: AppModeEnum className?: string } export function AppTypeLabel({ type, className }: AppTypeLabelProps) { const { t } = useTranslation() - let label = '' - if (type === AppModeEnum.CHAT) - label = t('typeSelector.chatbot', { ns: 'app' }) - if (type === AppModeEnum.AGENT_CHAT) - label = t('typeSelector.agent', { ns: 'app' }) - if (type === AppModeEnum.COMPLETION) - label = t('typeSelector.completion', { ns: 'app' }) - if (type === AppModeEnum.ADVANCED_CHAT) - label = t('typeSelector.advanced', { ns: 'app' }) - if (type === AppModeEnum.WORKFLOW) - label = t('typeSelector.workflow', { ns: 'app' }) - return {label} + return {getAppTypeLabel(type, t)} } diff --git a/web/app/components/app/workflow-log/detail.spec.tsx b/web/app/components/app/workflow-log/detail.spec.tsx index 1ed7193d42..806c6e71b2 100644 --- a/web/app/components/app/workflow-log/detail.spec.tsx +++ b/web/app/components/app/workflow-log/detail.spec.tsx @@ -19,7 +19,7 @@ import DetailPanel from './detail' // ============================================================================ const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), diff --git a/web/app/components/app/workflow-log/detail.tsx b/web/app/components/app/workflow-log/detail.tsx index ce85653e71..99d2c70228 100644 --- a/web/app/components/app/workflow-log/detail.tsx +++ b/web/app/components/app/workflow-log/detail.tsx @@ -1,12 +1,12 @@ 'use client' import type { FC } from 'react' import { RiCloseLine, RiPlayLargeLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' import { useStore } from '@/app/components/app/store' import TooltipPlus from '@/app/components/base/tooltip' import { WorkflowContextProvider } from '@/app/components/workflow/context' import Run from '@/app/components/workflow/run' +import { useRouter } from '@/next/navigation' type ILogDetail = { runID: string diff --git a/web/app/components/app/workflow-log/index.spec.tsx b/web/app/components/app/workflow-log/index.spec.tsx index f8e3f16e25..92f8eddf83 100644 --- a/web/app/components/app/workflow-log/index.spec.tsx +++ b/web/app/components/app/workflow-log/index.spec.tsx @@ -47,13 +47,13 @@ vi.mock('ahooks', () => ({ }, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), }), })) -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children: React.ReactNode, href: string }) =>
    {children}, })) diff --git a/web/app/components/app/workflow-log/list.spec.tsx b/web/app/components/app/workflow-log/list.spec.tsx index 760d222692..36cc911248 100644 --- a/web/app/components/app/workflow-log/list.spec.tsx +++ b/web/app/components/app/workflow-log/list.spec.tsx @@ -23,7 +23,7 @@ import WorkflowAppLogList from './list' // ============================================================================ const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), diff --git a/web/app/components/apps/__tests__/app-card.spec.tsx b/web/app/components/apps/__tests__/app-card.spec.tsx index 9bc23ce199..86c87e0c5b 100644 --- a/web/app/components/apps/__tests__/app-card.spec.tsx +++ b/web/app/components/apps/__tests__/app-card.spec.tsx @@ -11,7 +11,7 @@ import AppCard from '../app-card' // Mock next/navigation const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), @@ -111,7 +111,7 @@ vi.mock('@/utils/time', () => ({ })) // Mock dynamic imports -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (importFn: () => Promise) => { const fnString = importFn.toString() @@ -543,6 +543,11 @@ describe('AppCard', () => { fireEvent.click(screen.getByTestId('popover-trigger')) fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + + // Fill in the confirmation input with app name + const deleteInput = screen.getByRole('textbox') + fireEvent.change(deleteInput, { target: { value: mockApp.name } }) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { @@ -556,6 +561,11 @@ describe('AppCard', () => { fireEvent.click(screen.getByTestId('popover-trigger')) fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + + // Fill in the confirmation input with app name + const deleteInput = screen.getByRole('textbox') + fireEvent.change(deleteInput, { target: { value: mockApp.name } }) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { @@ -572,6 +582,11 @@ describe('AppCard', () => { fireEvent.click(screen.getByTestId('popover-trigger')) fireEvent.click(await screen.findByRole('button', { name: 'common.operation.delete' })) expect(await screen.findByRole('alertdialog')).toBeInTheDocument() + + // Fill in the confirmation input with app name + const deleteInput = screen.getByRole('textbox') + fireEvent.change(deleteInput, { target: { value: mockApp.name } }) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' })) await waitFor(() => { diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 989bf6a788..877c392e6d 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -8,7 +8,7 @@ import List from '../list' const mockReplace = vi.fn() const mockRouter = { replace: mockReplace } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => mockRouter, useSearchParams: () => new URLSearchParams(''), })) @@ -124,7 +124,7 @@ vi.mock('@/hooks/use-pay', () => ({ CheckModal: () => null, })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (importFn: () => Promise) => { const fnString = importFn.toString() diff --git a/web/app/components/apps/__tests__/new-app-card.spec.tsx b/web/app/components/apps/__tests__/new-app-card.spec.tsx index f4c357b9f9..9c98936bea 100644 --- a/web/app/components/apps/__tests__/new-app-card.spec.tsx +++ b/web/app/components/apps/__tests__/new-app-card.spec.tsx @@ -4,7 +4,7 @@ import * as React from 'react' import CreateAppCard from '../new-app-card' const mockReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), @@ -18,7 +18,7 @@ vi.mock('@/context/provider-context', () => ({ }), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (importFn: () => Promise<{ default: React.ComponentType }>) => { const fnString = importFn.toString() diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 471b3420d1..9a8abf6443 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -7,8 +7,6 @@ import type { CreateAppModalProps } from '@/app/components/explore/create-app-mo import type { EnvironmentVariable } from '@/app/components/workflow/types' import type { App } from '@/types/app' import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react' -import dynamic from 'next/dynamic' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -36,6 +34,8 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' +import dynamic from '@/next/dynamic' +import { useRouter } from '@/next/navigation' import { useGetUserCanAccessApp } from '@/service/access-control' import { copyApp, exportAppConfig, updateAppInfo } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' @@ -82,6 +82,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const [showDuplicateModal, setShowDuplicateModal] = useState(false) const [showSwitchModal, setShowSwitchModal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) + const [confirmDeleteInput, setConfirmDeleteInput] = useState('') const [showAccessControl, setShowAccessControl] = useState(false) const [secretEnvList, setSecretEnvList] = useState([]) const { mutateAsync: mutateDeleteApp, isPending: isDeleting } = useDeleteAppMutation() @@ -100,6 +101,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } finally { setShowConfirmDelete(false) + setConfirmDeleteInput('') } }, [app.id, mutateDeleteApp, notify, onPlanInfoChanged, t]) @@ -108,6 +110,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { return setShowConfirmDelete(open) + if (!open) + setConfirmDeleteInput('') }, [isDeleting]) const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ @@ -521,12 +525,28 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {t('deleteAppConfirmContent', { ns: 'app' })} +
    + + setConfirmDeleteInput(e.target.value)} + /> +
    {t('operation.cancel', { ns: 'common' })} - + {t('operation.confirm', { ns: 'common' })} diff --git a/web/app/components/apps/footer.tsx b/web/app/components/apps/footer.tsx index 3a0e960e0d..9147ccf6a6 100644 --- a/web/app/components/apps/footer.tsx +++ b/web/app/components/apps/footer.tsx @@ -1,7 +1,7 @@ import { RiDiscordFill, RiDiscussLine, RiGithubFill } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' type CustomLinkProps = { href: string diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d..b6ca60bd7b 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 6ae422f716..2ef344f816 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -2,19 +2,19 @@ import type { FC } from 'react' import { useDebounceFn } from 'ahooks' -import dynamic from 'next/dynamic' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' +import dynamic from '@/next/dynamic' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum, AppModes } from '@/types/app' import { cn } from '@/utils/classnames' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
    - + import('@/app/components/app/create-app-modal'), { diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index e1d8e52eac..00af15e24d 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -5,17 +5,12 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import * as React from 'react' import { useEffect } from 'react' -import { AMPLITUDE_API_KEY, IS_CLOUD_EDITION } from '@/config' +import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config' export type IAmplitudeProps = { sessionReplaySampleRate?: number } -// Check if Amplitude should be enabled -export const isAmplitudeEnabled = () => { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx similarity index 87% rename from web/app/components/base/amplitude/AmplitudeProvider.spec.tsx rename to web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index 2402c84a3e..5835634eb7 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () => { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts similarity index 96% rename from web/app/components/base/amplitude/utils.spec.ts rename to web/app/components/base/amplitude/__tests__/utils.spec.ts index c69fc93aa4..f1ff5db1e3 100644 --- a/web/app/components/base/amplitude/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -1,4 +1,4 @@ -import { resetUser, setUserId, setUserProperties, trackEvent } from './utils' +import { resetUser, setUserId, setUserProperties, trackEvent } from '../utils' const mockState = vi.hoisted(() => ({ enabled: true, @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('./AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, +vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.spec.ts b/web/app/components/base/amplitude/index.spec.ts deleted file mode 100644 index 919c0b68d1..0000000000 --- a/web/app/components/base/amplitude/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from './index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from './utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339e..44cbf728e2 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 0000000000..5dfa0e7b53 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec..8faa8e852e 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent = (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/base/audio-btn/__tests__/index.spec.tsx b/web/app/components/base/audio-btn/__tests__/index.spec.tsx index c8d8ee851b..8f6c26d12b 100644 --- a/web/app/components/base/audio-btn/__tests__/index.spec.tsx +++ b/web/app/components/base/audio-btn/__tests__/index.spec.tsx @@ -1,14 +1,14 @@ import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import i18next from 'i18next' -import { useParams, usePathname } from 'next/navigation' +import { useParams, usePathname } from '@/next/navigation' import AudioBtn from '../index' const mockPlayAudio = vi.fn() const mockPauseAudio = vi.fn() const mockGetAudioPlayer = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: vi.fn(), usePathname: vi.fn(), })) diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 8bea3193c8..47fefe19e5 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -1,10 +1,10 @@ 'use client' import { t } from 'i18next' -import { useParams, usePathname } from 'next/navigation' import { useState } from 'react' import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' import Loading from '@/app/components/base/loading' import Tooltip from '@/app/components/base/tooltip' +import { useParams, usePathname } from '@/next/navigation' import s from './style.module.css' type AudioBtnProps = { diff --git a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx index 60a5da5d49..bd5f01bcda 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx @@ -25,7 +25,7 @@ vi.mock('../context', () => ({ useChatWithHistoryContext: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: vi.fn(), replace: vi.fn(), diff --git a/web/app/components/base/chat/chat-with-history/__tests__/header-in-mobile.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/header-in-mobile.spec.tsx index 84bf9134d6..d75f9897a7 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/header-in-mobile.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/header-in-mobile.spec.tsx @@ -22,7 +22,7 @@ vi.mock('../context', () => ({ ChatWithHistoryContext: { Provider: ({ children }: { children: React.ReactNode }) =>
    {children}
    }, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: vi.fn(), replace: vi.fn(), diff --git a/web/app/components/base/chat/chat-with-history/__tests__/index.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/index.spec.tsx index 167cc7b385..e306569140 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/index.spec.tsx @@ -26,7 +26,7 @@ vi.mock('@/hooks/use-document-title', () => ({ default: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: vi.fn(), replace: vi.fn(), diff --git a/web/app/components/base/chat/chat-with-history/sidebar/__tests__/index.spec.tsx b/web/app/components/base/chat/chat-with-history/sidebar/__tests__/index.spec.tsx index 896161f66c..bb62869f21 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/__tests__/index.spec.tsx @@ -87,7 +87,7 @@ vi.mock('@/context/global-public-context', () => ({ })) // Mock next/navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx index da989d8b7c..92fa9ea42e 100644 --- a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx @@ -1,8 +1,8 @@ import type { ChatConfig, ChatItemInTree } from '../../types' import type { FileEntity } from '@/app/components/base/file-uploader/types' import { act, renderHook } from '@testing-library/react' -import { useParams, usePathname } from 'next/navigation' import { WorkflowRunningStatus } from '@/app/components/workflow/types' +import { useParams, usePathname } from '@/next/navigation' import { sseGet, ssePost } from '@/service/base' import { useChat } from '../hooks' @@ -28,7 +28,7 @@ vi.mock('@/hooks/use-timestamp', () => ({ default: () => ({ formatTime: vi.fn().mockReturnValue('10:00 AM') }), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: vi.fn(() => ({})), usePathname: vi.fn(() => ''), useRouter: vi.fn(() => ({})), @@ -141,6 +141,145 @@ describe('useChat', () => { expect(result.current.chatList[0].suggestedQuestions).toEqual(['Ask Bob']) }) + describe('opening statement referential stability', () => { + it('should keep the same item reference across multiple streaming chatTree mutations', () => { + let callbacks: HookCallbacks + + vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { + callbacks = options as HookCallbacks + }) + + const config = { + opening_statement: 'Welcome!', + suggested_questions: ['Q1', 'Q2'], + } + const { result } = renderHook(() => useChat(config as ChatConfig)) + + const openerInitial = result.current.chatList[0] + expect(openerInitial.isOpeningStatement).toBe(true) + expect(openerInitial.content).toBe('Welcome!') + + act(() => { + result.current.handleSend('url', { query: 'hello' }, {}) + }) + + act(() => { + callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' }) + }) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-1 ', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' }) + }) + expect(result.current.chatList.length).toBeGreaterThan(1) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-2 ', false, { messageId: 'm-1' }) + }) + expect(result.current.chatList[0]).toBe(openerInitial) + + act(() => { + callbacks.onData('chunk-3', false, { messageId: 'm-1' }) + callbacks.onMessageEnd({ metadata: { retriever_resources: [] } }) + callbacks.onWorkflowFinished({ data: { status: 'succeeded' } }) + callbacks.onCompleted() + }) + expect(result.current.chatList[0]).toBe(openerInitial) + expect(result.current.chatList.at(-1)!.content).toBe('chunk-1 chunk-2 chunk-3') + }) + + it('should keep stable reference when getIntroduction identity changes but output is identical', () => { + const config = { + opening_statement: 'Hello {{name}}', + suggested_questions: ['Ask about {{name}}'], + } + + const { result, rerender } = renderHook( + ({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings), + { initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } }, + ) + + const openerBefore = result.current.chatList[0] + expect(openerBefore.content).toBe('Hello Alice') + expect(openerBefore.suggestedQuestions).toEqual(['Ask about Alice']) + + rerender({ fs: { inputs: { name: 'Alice' }, inputsForm: [] } }) + + expect(result.current.chatList[0]).toBe(openerBefore) + }) + + it('should produce a new item when the processed content actually changes', () => { + const config = { + opening_statement: 'Hello {{name}}', + suggested_questions: ['Ask {{name}}'], + } + + const { result, rerender } = renderHook( + ({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings), + { initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } }, + ) + + const before = result.current.chatList[0] + + rerender({ fs: { inputs: { name: 'Bob' }, inputsForm: [] } }) + + const after = result.current.chatList[0] + expect(after).not.toBe(before) + expect(after.content).toBe('Hello Bob') + expect(after.suggestedQuestions).toEqual(['Ask Bob']) + }) + + it('should keep content and suggestedQuestions stable for opener already in prevChatTree even when sibling metadata changes', () => { + let callbacks: HookCallbacks + vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => { + callbacks = options as HookCallbacks + }) + + const config = { + opening_statement: 'Hello updated', + suggested_questions: ['S1'], + } + const prevChatTree = [{ + id: 'opening-statement', + content: 'old', + isAnswer: true, + isOpeningStatement: true, + suggestedQuestions: [], + }] + + const { result } = renderHook(() => + useChat(config as ChatConfig, undefined, prevChatTree as ChatItemInTree[]), + ) + + const openerBefore = result.current.chatList[0] + expect(openerBefore.content).toBe('Hello updated') + expect(openerBefore.suggestedQuestions).toEqual(['S1']) + + const contentBefore = openerBefore.content + const suggestionsBefore = openerBefore.suggestedQuestions + + act(() => { + result.current.handleSend('url', { query: 'msg' }, {}) + }) + act(() => { + callbacks.onData('resp', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' }) + }) + + expect(result.current.chatList.length).toBeGreaterThan(1) + const openerAfter = result.current.chatList[0] + expect(openerAfter.content).toBe(contentBefore) + expect(openerAfter.suggestedQuestions).toBe(suggestionsBefore) + }) + + it('should use a stable id of "opening-statement"', () => { + const { result } = renderHook(() => + useChat({ opening_statement: 'Hi' } as ChatConfig), + ) + expect(result.current.chatList[0].id).toBe('opening-statement') + }) + }) + describe('handleSend', () => { it('should block send if already responding', async () => { const { result } = renderHook(() => useChat()) diff --git a/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx b/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx index baff417669..836397a586 100644 --- a/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx +++ b/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx @@ -111,7 +111,7 @@ vi.mock('@/app/components/base/chat/chat/log', () => ({ default: () => , })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: vi.fn(() => ({ appId: 'test-app' })), usePathname: vi.fn(() => '/apps/test-app'), })) diff --git a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx index cb1d0f2a55..f628b7de82 100644 --- a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx @@ -208,7 +208,7 @@ vi.mock('../../check-input-forms-hooks', () => ({ // --------------------------------------------------------------------------- // Next.js navigation // --------------------------------------------------------------------------- -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ token: 'test-token' }), useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', diff --git a/web/app/components/base/chat/chat/citation/popup.tsx b/web/app/components/base/chat/chat/citation/popup.tsx index 7dc2baeb88..3a1d4bf251 100644 --- a/web/app/components/base/chat/chat/citation/popup.tsx +++ b/web/app/components/base/chat/chat/citation/popup.tsx @@ -1,6 +1,5 @@ import type { FC, MouseEvent } from 'react' import type { Resources } from './index' -import Link from 'next/link' import { Fragment, useState } from 'react' import { useTranslation } from 'react-i18next' import FileIcon from '@/app/components/base/file-icon' @@ -9,6 +8,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' +import Link from '@/next/link' import { useDocumentDownload } from '@/service/knowledge/use-document' import { downloadUrl } from '@/utils/download' import ProgressTooltip from './progress-tooltip' diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 307fd52443..a0f335f567 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -15,7 +15,6 @@ import type { import { uniqBy } from 'es-toolkit/compat' import { noop } from 'es-toolkit/function' import { produce, setAutoFreeze } from 'immer' -import { useParams, usePathname } from 'next/navigation' import { useCallback, useEffect, @@ -33,6 +32,7 @@ import { import { useToastContext } from '@/app/components/base/toast/context' import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types' import useTimestamp from '@/hooks/use-timestamp' +import { useParams, usePathname } from '@/next/navigation' import { sseGet, ssePost, @@ -88,30 +88,54 @@ export const useChat = ( return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || []) }, [formSettings?.inputs, formSettings?.inputsForm]) + const processedOpeningContent = config?.opening_statement + ? getIntroduction(config.opening_statement) + : undefined + const processedSuggestionsKey = config?.suggested_questions + ? JSON.stringify(config.suggested_questions.map(q => getIntroduction(q))) + : undefined + + const openingStatementItem = useMemo(() => { + if (!processedOpeningContent) + return null + return { + id: 'opening-statement', + content: processedOpeningContent, + isAnswer: true, + isOpeningStatement: true, + suggestedQuestions: processedSuggestionsKey + ? JSON.parse(processedSuggestionsKey) as string[] + : undefined, + } + }, [processedOpeningContent, processedSuggestionsKey]) + + const threadOpener = useMemo( + () => threadMessages.find(item => item.isOpeningStatement) ?? null, + [threadMessages], + ) + + const mergedOpeningItem = useMemo(() => { + if (!threadOpener || !openingStatementItem) + return null + return { + ...threadOpener, + content: openingStatementItem.content, + suggestedQuestions: openingStatementItem.suggestedQuestions, + } + }, [threadOpener, openingStatementItem]) + /** Final chat list that will be rendered */ const chatList = useMemo(() => { const ret = [...threadMessages] - if (config?.opening_statement) { + if (openingStatementItem) { const index = threadMessages.findIndex(item => item.isOpeningStatement) - if (index > -1) { - ret[index] = { - ...ret[index], - content: getIntroduction(config.opening_statement), - suggestedQuestions: config.suggested_questions?.map(item => getIntroduction(item)), - } - } - else { - ret.unshift({ - id: 'opening-statement', - content: getIntroduction(config.opening_statement), - isAnswer: true, - isOpeningStatement: true, - suggestedQuestions: config.suggested_questions?.map(item => getIntroduction(item)), - }) - } + if (index > -1 && mergedOpeningItem) + ret[index] = mergedOpeningItem + else if (index === -1) + ret.unshift(openingStatementItem) } return ret - }, [threadMessages, config, getIntroduction]) + }, [threadMessages, openingStatementItem, mergedOpeningItem]) useEffect(() => { setAutoFreeze(false) diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx index aad2d3d09b..689a9e0439 100644 --- a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx @@ -9,7 +9,7 @@ vi.mock('../../context', () => ({ useEmbeddedChatbotContext: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ token: 'test-token' }), useRouter: () => ({ push: vi.fn() }), usePathname: () => '/', diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index b47fec1d0a..5881f565a4 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -158,7 +158,7 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { rootNodes.push(questionNode) } else { - map[parentMessageId]?.children!.push(questionNode) + map[parentMessageId].children!.push(questionNode) } } } diff --git a/web/app/components/base/confirm/index.tsx b/web/app/components/base/confirm/index.tsx index 27b67ea507..91d9e7bfb8 100644 --- a/web/app/components/base/confirm/index.tsx +++ b/web/app/components/base/confirm/index.tsx @@ -26,6 +26,11 @@ export type IConfirm = { showConfirm?: boolean showCancel?: boolean maskClosable?: boolean + confirmInputLabel?: string + confirmInputPlaceholder?: string + confirmInputValue?: string + onConfirmInputChange?: (value: string) => void + confirmInputMatchValue?: string } function Confirm({ @@ -42,6 +47,11 @@ function Confirm({ isLoading = false, isDisabled = false, maskClosable = true, + confirmInputLabel, + confirmInputPlaceholder, + confirmInputValue = '', + onConfirmInputChange, + confirmInputMatchValue, }: IConfirm) { const { t } = useTranslation() const dialogRef = useRef(null) @@ -51,12 +61,13 @@ function Confirm({ const confirmTxt = confirmText || `${t('operation.confirm', { ns: 'common' })}` const cancelTxt = cancelText || `${t('operation.cancel', { ns: 'common' })}` + const isConfirmDisabled = isDisabled || (confirmInputMatchValue ? confirmInputValue !== confirmInputMatchValue : false) useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.key === 'Escape') onCancel() - if (event.key === 'Enter' && isShow) { + if (event.key === 'Enter' && isShow && !isConfirmDisabled) { event.preventDefault() onConfirm() } @@ -66,7 +77,7 @@ function Confirm({ return () => { document.removeEventListener('keydown', handleKeyDown) } - }, [onCancel, onConfirm, isShow]) + }, [onCancel, onConfirm, isShow, isConfirmDisabled]) const handleClickOutside = (event: MouseEvent) => { if (maskClosable && dialogRef.current && !dialogRef.current.contains(event.target as Node)) @@ -123,11 +134,25 @@ function Confirm({ {title}
    -
    {content}
    +
    {content}
    + {confirmInputLabel && ( +
    + + onConfirmInputChange?.(e.target.value)} + /> +
    + )}
    {showCancel && } - {showConfirm && } + {showConfirm && }
    diff --git a/web/app/components/base/encrypted-bottom/index.tsx b/web/app/components/base/encrypted-bottom/index.tsx index 5a9bc9b488..5f35433612 100644 --- a/web/app/components/base/encrypted-bottom/index.tsx +++ b/web/app/components/base/encrypted-bottom/index.tsx @@ -1,7 +1,7 @@ import type { I18nKeysWithPrefix } from '@/types/i18n' import { RiLock2Fill } from '@remixicon/react' -import Link from 'next/link' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' import { cn } from '@/utils/classnames' type EncryptedKey = I18nKeysWithPrefix<'common', 'provider.encrypted.'> diff --git a/web/app/components/base/features/new-feature-panel/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/__tests__/index.spec.tsx index 20632c4954..77f9a0253b 100644 --- a/web/app/components/base/features/new-feature-panel/__tests__/index.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { render, screen } from '@testing-library/react' import { FeaturesProvider } from '../../context' import NewFeaturePanel from '../index' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/app/test-app-id/configuration', })) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/__tests__/index.spec.tsx index f2ddc5482d..03ddbc6322 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/__tests__/index.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ import AnnotationReply from '../index' const originalConsoleError = console.error const mockPush = vi.fn() let mockPathname = '/app/test-app-id/configuration' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush }), usePathname: () => mockPathname, })) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx index df8982407c..1ad4ef613e 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx @@ -2,7 +2,6 @@ import type { OnFeaturesChange } from '@/app/components/base/features/types' import type { AnnotationReplyConfig } from '@/models/debug' import { RiEqualizer2Line, RiExternalLinkLine } from '@remixicon/react' import { produce } from 'immer' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -14,6 +13,7 @@ import FeatureCard from '@/app/components/base/features/new-feature-panel/featur import { MessageFast } from '@/app/components/base/icons/src/vender/features' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { ANNOTATION_DEFAULT } from '@/config' +import { usePathname, useRouter } from '@/next/navigation' type Props = { disabled?: boolean diff --git a/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/param-config-content.spec.tsx b/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/param-config-content.spec.tsx index 66d870f28f..535d40e00a 100644 --- a/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/param-config-content.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/param-config-content.spec.tsx @@ -22,7 +22,7 @@ const mockUseAppVoices = vi.fn((_appId: string, _language?: string) => ({ data: mockVoiceItems, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, useParams: () => ({}), })) diff --git a/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/voice-settings.spec.tsx b/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/voice-settings.spec.tsx index 658d5f500b..f77802c133 100644 --- a/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/voice-settings.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/text-to-speech/__tests__/voice-settings.spec.tsx @@ -35,7 +35,7 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) =>
    {children}
    , })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => '/app/test-app-id/configuration', useParams: () => ({ appId: 'test-app-id' }), })) diff --git a/web/app/components/base/features/new-feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/new-feature-panel/text-to-speech/param-config-content.tsx index 11db9346ff..d4e008c4e6 100644 --- a/web/app/components/base/features/new-feature-panel/text-to-speech/param-config-content.tsx +++ b/web/app/components/base/features/new-feature-panel/text-to-speech/param-config-content.tsx @@ -3,7 +3,6 @@ import type { OnFeaturesChange } from '@/app/components/base/features/types' import type { Item } from '@/app/components/base/select' import { Listbox, ListboxButton, ListboxOption, ListboxOptions, Transition } from '@headlessui/react' import { produce } from 'immer' -import { usePathname } from 'next/navigation' import * as React from 'react' import { Fragment } from 'react' import { useTranslation } from 'react-i18next' @@ -13,6 +12,7 @@ import { useFeatures, useFeaturesStore } from '@/app/components/base/features/ho import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import { languages } from '@/i18n-config/language' +import { usePathname } from '@/next/navigation' import { useAppVoices } from '@/service/use-apps' import { TtsAutoPlay } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/base/file-uploader/dynamic-pdf-preview.spec.tsx b/web/app/components/base/file-uploader/__tests__/dynamic-pdf-preview.spec.tsx similarity index 93% rename from web/app/components/base/file-uploader/dynamic-pdf-preview.spec.tsx rename to web/app/components/base/file-uploader/__tests__/dynamic-pdf-preview.spec.tsx index 1f15c419eb..868f153dbc 100644 --- a/web/app/components/base/file-uploader/dynamic-pdf-preview.spec.tsx +++ b/web/app/components/base/file-uploader/__tests__/dynamic-pdf-preview.spec.tsx @@ -1,5 +1,5 @@ import { fireEvent, render, screen } from '@testing-library/react' -import DynamicPdfPreview from './dynamic-pdf-preview' +import DynamicPdfPreview from '../dynamic-pdf-preview' type DynamicPdfPreviewProps = { url: string @@ -40,11 +40,11 @@ const mockPdfPreview = vi.hoisted(() => vi.fn(() => null), ) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: mockDynamic, })) -vi.mock('./pdf-preview', () => ({ +vi.mock('../pdf-preview', () => ({ default: mockPdfPreview, })) @@ -78,7 +78,7 @@ describe('dynamic-pdf-preview', () => { expect(loaded).toBeInstanceOf(Promise) const loadedModule = (await loaded) as { default: unknown } - const pdfPreviewModule = await import('./pdf-preview') + const pdfPreviewModule = await import('../pdf-preview') expect(loadedModule.default).toBe(pdfPreviewModule.default) }) diff --git a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts index 8343974967..824a3b7a03 100644 --- a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts @@ -6,7 +6,7 @@ import { useFile, useFileSizeLimit } from '../hooks' const mockNotify = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ token: undefined }), })) diff --git a/web/app/components/base/file-uploader/dynamic-pdf-preview.tsx b/web/app/components/base/file-uploader/dynamic-pdf-preview.tsx index 116db89864..225d5664c2 100644 --- a/web/app/components/base/file-uploader/dynamic-pdf-preview.tsx +++ b/web/app/components/base/file-uploader/dynamic-pdf-preview.tsx @@ -1,6 +1,6 @@ 'use client' -import dynamic from 'next/dynamic' +import dynamic from '@/next/dynamic' type DynamicPdfPreviewProps = { url: string diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 4aab60175c..27345b22ff 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -4,7 +4,6 @@ import type { FileUpload } from '@/app/components/base/features/types' import type { FileUploadConfigResponse } from '@/models/common' import { noop } from 'es-toolkit/function' import { produce } from 'immer' -import { useParams } from 'next/navigation' import { useCallback, useState, @@ -20,6 +19,7 @@ import { } from '@/app/components/base/file-uploader/constants' import { useToastContext } from '@/app/components/base/toast/context' import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { useParams } from '@/next/navigation' import { uploadRemoteFileInfo } from '@/service/common' import { TransferMethod } from '@/types/app' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/base/form/components/field/__tests__/file-uploader.spec.tsx b/web/app/components/base/form/components/field/__tests__/file-uploader.spec.tsx index dee7c97222..bff8e9cbf9 100644 --- a/web/app/components/base/form/components/field/__tests__/file-uploader.spec.tsx +++ b/web/app/components/base/form/components/field/__tests__/file-uploader.spec.tsx @@ -27,7 +27,7 @@ vi.mock('../../..', () => ({ useFieldContext: () => mockField, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ token: 'test-token' }), })) diff --git a/web/app/components/base/form/form-scenarios/base/__tests__/field.spec.tsx b/web/app/components/base/form/form-scenarios/base/__tests__/field.spec.tsx index 1d7734f670..81190dc277 100644 --- a/web/app/components/base/form/form-scenarios/base/__tests__/field.spec.tsx +++ b/web/app/components/base/form/form-scenarios/base/__tests__/field.spec.tsx @@ -6,7 +6,7 @@ import { useAppForm } from '../../..' import BaseField from '../field' import { BaseFieldType } from '../types' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({}), })) diff --git a/web/app/components/base/ga/__tests__/index.spec.tsx b/web/app/components/base/ga/__tests__/index.spec.tsx index ee7f7a2a9d..619c4514dc 100644 --- a/web/app/components/base/ga/__tests__/index.spec.tsx +++ b/web/app/components/base/ga/__tests__/index.spec.tsx @@ -31,11 +31,11 @@ vi.mock('@/config', () => ({ }, })) -vi.mock('next/headers', () => ({ +vi.mock('@/next/headers', () => ({ headers: mockHeaders, })) -vi.mock('next/script', () => ({ +vi.mock('@/next/script', () => ({ default: ({ id, strategy, diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 7225dcf428..3e19afd974 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -1,8 +1,8 @@ import type { FC } from 'react' -import { headers } from 'next/headers' -import Script from 'next/script' import * as React from 'react' import { IS_CE_EDITION, IS_PROD } from '@/config' +import { headers } from '@/next/headers' +import Script from '@/next/script' export enum GaType { admin = 'admin', diff --git a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts index f79ea98081..e4295dfb09 100644 --- a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts @@ -9,7 +9,7 @@ vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ token: undefined }), })) diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index 03cf0feeca..9251d3888f 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -1,9 +1,9 @@ import type { ClipboardEvent } from 'react' import type { ImageFile, VisionSettings } from '@/types/app' -import { useParams } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useToastContext } from '@/app/components/base/toast/context' +import { useParams } from '@/next/navigation' import { ALLOW_FILE_EXTENSIONS, TransferMethod } from '@/types/app' import { getImageUploadErrorMessage, imageUpload } from './utils' diff --git a/web/app/components/base/linked-apps-panel/__tests__/index.spec.tsx b/web/app/components/base/linked-apps-panel/__tests__/index.spec.tsx index 27408531c4..5576fb289e 100644 --- a/web/app/components/base/linked-apps-panel/__tests__/index.spec.tsx +++ b/web/app/components/base/linked-apps-panel/__tests__/index.spec.tsx @@ -4,7 +4,7 @@ import { vi } from 'vitest' import { AppModeEnum } from '@/types/app' import LinkedAppsPanel from '../index' -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, className }: { children: React.ReactNode, href: string, className: string }) => ( {children} diff --git a/web/app/components/base/linked-apps-panel/index.tsx b/web/app/components/base/linked-apps-panel/index.tsx index adc8ccf729..1ce76e0647 100644 --- a/web/app/components/base/linked-apps-panel/index.tsx +++ b/web/app/components/base/linked-apps-panel/index.tsx @@ -2,9 +2,9 @@ import type { FC } from 'react' import type { RelatedApp } from '@/models/datasets' import { RiArrowRightUpLine } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import AppIcon from '@/app/components/base/app-icon' +import Link from '@/next/link' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx index 308232fd0f..745b7657d7 100644 --- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx @@ -21,6 +21,8 @@ let clientWidthSpy: { mockRestore: () => void } | null = null let clientHeightSpy: { mockRestore: () => void } | null = null let offsetWidthSpy: { mockRestore: () => void } | null = null let offsetHeightSpy: { mockRestore: () => void } | null = null +let consoleErrorSpy: ReturnType | null = null +let consoleWarnSpy: ReturnType | null = null type AudioContextCtor = new () => unknown type WindowWithLegacyAudio = Window & { @@ -83,6 +85,8 @@ describe('CodeBlock', () => { beforeEach(() => { vi.clearAllMocks() mockUseTheme.mockReturnValue({ theme: Theme.light }) + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900) clientHeightSpy = vi.spyOn(HTMLElement.prototype, 'clientHeight', 'get').mockReturnValue(400) offsetWidthSpy = vi.spyOn(HTMLElement.prototype, 'offsetWidth', 'get').mockReturnValue(900) @@ -98,6 +102,10 @@ describe('CodeBlock', () => { afterEach(() => { vi.useRealTimers() + consoleErrorSpy?.mockRestore() + consoleWarnSpy?.mockRestore() + consoleErrorSpy = null + consoleWarnSpy = null clientWidthSpy?.mockRestore() clientHeightSpy?.mockRestore() offsetWidthSpy?.mockRestore() diff --git a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx index e8b956cbbf..4f22468157 100644 --- a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx @@ -163,25 +163,16 @@ describe('ThinkBlock', () => { expect(screen.getByText(/Thought/)).toBeInTheDocument() }) - it('should NOT stop timer when isResponding is undefined (outside ChatContextProvider)', () => { - // Render without ChatContextProvider + it('should stop timer when isResponding is undefined (historical conversation outside active response)', () => { + // Render without ChatContextProvider — simulates historical conversation render(

    Content without ENDTHINKFLAG

    , ) - // Initial state should show "Thinking..." - expect(screen.getByText(/Thinking\.\.\./)).toBeInTheDocument() - - // Advance timer - act(() => { - vi.advanceTimersByTime(2000) - }) - - // Timer should still be running (showing "Thinking..." not "Thought") - expect(screen.getByText(/Thinking\.\.\./)).toBeInTheDocument() - expect(screen.getByText(/\(2\.0s\)/)).toBeInTheDocument() + // Timer should be stopped immediately — isResponding undefined means not in active response + expect(screen.getByText(/Thought/)).toBeInTheDocument() }) }) diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index 837929cfff..412c61d52d 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -1,5 +1,4 @@ import ReactEcharts from 'echarts-for-react' -import dynamic from 'next/dynamic' import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react' import SyntaxHighlighter from 'react-syntax-highlighter' import { @@ -12,6 +11,7 @@ import MarkdownMusic from '@/app/components/base/markdown-blocks/music' import ErrorBoundary from '@/app/components/base/markdown/error-boundary' import SVGBtn from '@/app/components/base/svg' import useTheme from '@/hooks/use-theme' +import dynamic from '@/next/dynamic' import { Theme } from '@/types/app' import SVGRenderer from '../svg-gallery' // Assumes svg-gallery.tsx is in /base directory @@ -85,13 +85,30 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any const processedRef = useRef(false) // Track if content was successfully processed const isInitialRenderRef = useRef(true) // Track if this is initial render const chartInstanceRef = useRef(null) // Direct reference to ECharts instance - const resizeTimerRef = useRef(null) // For debounce handling + const resizeTimerRef = useRef | null>(null) // For debounce handling + const chartReadyTimerRef = useRef | null>(null) const finishedEventCountRef = useRef(0) // Track finished event trigger count const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') const isDarkMode = theme === Theme.dark + const clearResizeTimer = useCallback(() => { + if (!resizeTimerRef.current) + return + + clearTimeout(resizeTimerRef.current) + resizeTimerRef.current = null + }, []) + + const clearChartReadyTimer = useCallback(() => { + if (!chartReadyTimerRef.current) + return + + clearTimeout(chartReadyTimerRef.current) + chartReadyTimerRef.current = null + }, []) + const echartsStyle = useMemo(() => ({ height: '350px', width: '100%', @@ -104,26 +121,27 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any // Debounce resize operations const debouncedResize = useCallback(() => { - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() resizeTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() resizeTimerRef.current = null }, 200) - }, []) + }, [clearResizeTimer]) // Handle ECharts instance initialization const handleChartReady = useCallback((instance: any) => { chartInstanceRef.current = instance // Force resize to ensure timeline displays correctly - setTimeout(() => { + clearChartReadyTimer() + chartReadyTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() + chartReadyTimerRef.current = null }, 200) - }, []) + }, [clearChartReadyTimer]) // Store event handlers in useMemo to avoid recreating them const echartsEvents = useMemo(() => ({ @@ -157,10 +175,20 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any return () => { window.removeEventListener('resize', handleResize) - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null } - }, [language, debouncedResize]) + }, [language, debouncedResize, clearResizeTimer, clearChartReadyTimer]) + + useEffect(() => { + return () => { + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null + echartsRef.current = null + } + }, [clearResizeTimer, clearChartReadyTimer]) // Process chart data when content changes useEffect(() => { // Only process echarts content diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index f920218152..184ed89274 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -39,9 +39,10 @@ const removeEndThink = (children: any): any => { const useThinkTimer = (children: any) => { const { isResponding } = useChatContext() + const endThinkDetected = hasEndThink(children) const [startTime] = useState(() => Date.now()) const [elapsedTime, setElapsedTime] = useState(0) - const [isComplete, setIsComplete] = useState(false) + const [isComplete, setIsComplete] = useState(() => endThinkDetected) const timerRef = useRef(null) useEffect(() => { @@ -61,11 +62,10 @@ const useThinkTimer = (children: any) => { useEffect(() => { // Stop timer when: // 1. Content has [ENDTHINKFLAG] marker (normal completion) - // 2. isResponding is explicitly false (user clicked stop button) - // Note: Don't stop when isResponding is undefined (component used outside ChatContextProvider) - if (hasEndThink(children) || isResponding === false) + // 2. isResponding is not true (false = user clicked stop, undefined = historical conversation) + if (endThinkDetected || !isResponding) setIsComplete(true) - }, [children, isResponding]) + }, [endThinkDetected, isResponding]) return { elapsedTime, isComplete } } diff --git a/web/app/components/base/markdown-with-directive/index.spec.tsx b/web/app/components/base/markdown-with-directive/__tests__/index.spec.tsx similarity index 96% rename from web/app/components/base/markdown-with-directive/index.spec.tsx rename to web/app/components/base/markdown-with-directive/__tests__/index.spec.tsx index fc4b813247..e71abd6620 100644 --- a/web/app/components/base/markdown-with-directive/index.spec.tsx +++ b/web/app/components/base/markdown-with-directive/__tests__/index.spec.tsx @@ -1,9 +1,9 @@ import { render, screen } from '@testing-library/react' import DOMPurify from 'dompurify' -import { validateDirectiveProps } from './components/markdown-with-directive-schema' -import WithIconCardItem from './components/with-icon-card-item' -import WithIconCardList from './components/with-icon-card-list' -import { MarkdownWithDirective } from './index' +import { validateDirectiveProps } from '../components/markdown-with-directive-schema' +import WithIconCardItem from '../components/with-icon-card-item' +import WithIconCardList from '../components/with-icon-card-list' +import { MarkdownWithDirective } from '../index' const FOUR_COLON_RE = /:{4}/ diff --git a/web/app/components/base/markdown-with-directive/components/markdown-with-directive-schema.spec.ts b/web/app/components/base/markdown-with-directive/components/__tests__/markdown-with-directive-schema.spec.ts similarity index 97% rename from web/app/components/base/markdown-with-directive/components/markdown-with-directive-schema.spec.ts rename to web/app/components/base/markdown-with-directive/components/__tests__/markdown-with-directive-schema.spec.ts index 9e74ed43b4..c69bdf4987 100644 --- a/web/app/components/base/markdown-with-directive/components/markdown-with-directive-schema.spec.ts +++ b/web/app/components/base/markdown-with-directive/components/__tests__/markdown-with-directive-schema.spec.ts @@ -1,4 +1,4 @@ -import { validateDirectiveProps } from './markdown-with-directive-schema' +import { validateDirectiveProps } from '../markdown-with-directive-schema' describe('markdown-with-directive-schema', () => { beforeEach(() => { diff --git a/web/app/components/base/markdown-with-directive/components/with-icon-card-item.spec.tsx b/web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-item.spec.tsx similarity index 96% rename from web/app/components/base/markdown-with-directive/components/with-icon-card-item.spec.tsx rename to web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-item.spec.tsx index dbe293dcf6..8a2d4a552b 100644 --- a/web/app/components/base/markdown-with-directive/components/with-icon-card-item.spec.tsx +++ b/web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-item.spec.tsx @@ -1,5 +1,5 @@ import { render, screen } from '@testing-library/react' -import WithIconCardItem from './with-icon-card-item' +import WithIconCardItem from '../with-icon-card-item' describe('WithIconCardItem', () => { beforeEach(() => { diff --git a/web/app/components/base/markdown-with-directive/components/with-icon-card-list.spec.tsx b/web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-list.spec.tsx similarity index 95% rename from web/app/components/base/markdown-with-directive/components/with-icon-card-list.spec.tsx rename to web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-list.spec.tsx index d5b701b01c..5698b4a921 100644 --- a/web/app/components/base/markdown-with-directive/components/with-icon-card-list.spec.tsx +++ b/web/app/components/base/markdown-with-directive/components/__tests__/with-icon-card-list.spec.tsx @@ -1,5 +1,5 @@ import { render, screen } from '@testing-library/react' -import WithIconCardList from './with-icon-card-list' +import WithIconCardList from '../with-icon-card-list' describe('WithIconCardList', () => { beforeEach(() => { diff --git a/web/app/components/base/markdown/__tests__/index.spec.tsx b/web/app/components/base/markdown/__tests__/index.spec.tsx index 5d0261b074..08c4527003 100644 --- a/web/app/components/base/markdown/__tests__/index.spec.tsx +++ b/web/app/components/base/markdown/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ const { mockReactMarkdownWrapper } = vi.hoisted(() => ({ mockReactMarkdownWrapper: vi.fn(), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: () => { const MockStreamdownWrapper = (props: { latexContent: string }) => { mockReactMarkdownWrapper(props) diff --git a/web/app/components/base/markdown/index.tsx b/web/app/components/base/markdown/index.tsx index 6faee9c260..5915816d7a 100644 --- a/web/app/components/base/markdown/index.tsx +++ b/web/app/components/base/markdown/index.tsx @@ -1,7 +1,7 @@ import type { SimplePluginInfo, StreamdownWrapperProps } from './streamdown-wrapper' import { flow } from 'es-toolkit/compat' -import dynamic from 'next/dynamic' import { memo, useMemo } from 'react' +import dynamic from '@/next/dynamic' import { cn } from '@/utils/classnames' import { preprocessLaTeX, preprocessThinkTag } from './markdown-utils' diff --git a/web/app/components/base/markdown/streamdown-wrapper.tsx b/web/app/components/base/markdown/streamdown-wrapper.tsx index 6fdf954edc..46db301adb 100644 --- a/web/app/components/base/markdown/streamdown-wrapper.tsx +++ b/web/app/components/base/markdown/streamdown-wrapper.tsx @@ -1,7 +1,6 @@ import type { ComponentType } from 'react' import type { Components, StreamdownProps } from 'streamdown' import { createMathPlugin } from '@streamdown/math' -import dynamic from 'next/dynamic' import { memo, useMemo } from 'react' import RemarkBreaks from 'remark-breaks' import { defaultRehypePlugins, defaultRemarkPlugins, Streamdown } from 'streamdown' @@ -18,6 +17,7 @@ import { VideoBlock, } from '@/app/components/base/markdown-blocks' import { ENABLE_SINGLE_DOLLAR_LATEX } from '@/config' +import dynamic from '@/next/dynamic' import { customUrlTransform } from './markdown-utils' import 'katex/dist/katex.min.css' diff --git a/web/app/components/base/new-audio-button/__tests__/index.spec.tsx b/web/app/components/base/new-audio-button/__tests__/index.spec.tsx index 64dd590012..23696fca74 100644 --- a/web/app/components/base/new-audio-button/__tests__/index.spec.tsx +++ b/web/app/components/base/new-audio-button/__tests__/index.spec.tsx @@ -1,15 +1,15 @@ import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import i18next from 'i18next' -import { useParams, usePathname } from 'next/navigation' import { beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { useParams, usePathname } from '@/next/navigation' import AudioBtn from '../index' const mockPlayAudio = vi.fn() const mockPauseAudio = vi.fn() const mockGetAudioPlayer = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: vi.fn(), usePathname: vi.fn(), })) diff --git a/web/app/components/base/new-audio-button/index.tsx b/web/app/components/base/new-audio-button/index.tsx index 7e1e1ccc78..c6569ff958 100644 --- a/web/app/components/base/new-audio-button/index.tsx +++ b/web/app/components/base/new-audio-button/index.tsx @@ -3,11 +3,11 @@ import { RiVolumeUpLine, } from '@remixicon/react' import { t } from 'i18next' -import { useParams, usePathname } from 'next/navigation' import { useState } from 'react' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' import Tooltip from '@/app/components/base/tooltip' +import { useParams, usePathname } from '@/next/navigation' type AudioBtnProps = { id?: string diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/use-llm-model-plugin-installed.spec.ts b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/use-llm-model-plugin-installed.spec.ts similarity index 96% rename from web/app/components/base/prompt-editor/plugins/workflow-variable-block/use-llm-model-plugin-installed.spec.ts rename to web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/use-llm-model-plugin-installed.spec.ts index f64865317f..fd77302d13 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/use-llm-model-plugin-installed.spec.ts +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/use-llm-model-plugin-installed.spec.ts @@ -1,7 +1,7 @@ import type { WorkflowNodesMap } from '@/app/components/base/prompt-editor/types' import { renderHook } from '@testing-library/react' import { BlockEnum } from '@/app/components/workflow/types' -import { useLlmModelPluginInstalled } from './use-llm-model-plugin-installed' +import { useLlmModelPluginInstalled } from '../use-llm-model-plugin-installed' let mockModelProviders: Array<{ provider: string }> = [] diff --git a/web/app/components/base/tag-management/__tests__/filter.spec.tsx b/web/app/components/base/tag-management/__tests__/filter.spec.tsx index 3cffac29b2..a455d1a791 100644 --- a/web/app/components/base/tag-management/__tests__/filter.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/filter.spec.tsx @@ -14,23 +14,11 @@ vi.mock('@/service/tag', () => ({ fetchTagList, })) -// Mock ahooks to avoid timer-related issues in tests vi.mock('ahooks', () => { return { - useDebounceFn: (fn: (...args: unknown[]) => void) => { - const ref = React.useRef(fn) - ref.current = fn - const stableRun = React.useRef((...args: unknown[]) => { - // Schedule to run after current event handler finishes, - // allowing React to process pending state updates first - Promise.resolve().then(() => ref.current(...args)) - }) - return { run: stableRun.current } - }, useMount: (fn: () => void) => { React.useEffect(() => { fn() - // eslint-disable-next-line react-hooks/exhaustive-deps }, []) }, } @@ -228,7 +216,6 @@ describe('TagFilter', () => { const searchInput = screen.getByRole('textbox') await user.type(searchInput, 'Front') - // With debounce mocked to be synchronous, results should be immediate expect(screen.getByText('Frontend')).toBeInTheDocument() expect(screen.queryByText('Backend')).not.toBeInTheDocument() expect(screen.queryByText('API Design')).not.toBeInTheDocument() @@ -257,22 +244,14 @@ describe('TagFilter', () => { const searchInput = screen.getByRole('textbox') await user.type(searchInput, 'Front') - // Wait for the debounced search to filter - await waitFor(() => { - expect(screen.queryByText('Backend')).not.toBeInTheDocument() - }) + expect(screen.queryByText('Backend')).not.toBeInTheDocument() - // Clear the search using the Input's clear button const clearButton = screen.getByTestId('input-clear') await user.click(clearButton) - // The input value should be cleared expect(searchInput).toHaveValue('') - // After the clear + microtask re-render, all app tags should be visible again - await waitFor(() => { - expect(screen.getByText('Backend')).toBeInTheDocument() - }) + expect(screen.getByText('Backend')).toBeInTheDocument() expect(screen.getByText('Frontend')).toBeInTheDocument() expect(screen.getByText('API Design')).toBeInTheDocument() }) diff --git a/web/app/components/base/tag-management/filter.tsx b/web/app/components/base/tag-management/filter.tsx index ad71334ddb..fcd59bcf7d 100644 --- a/web/app/components/base/tag-management/filter.tsx +++ b/web/app/components/base/tag-management/filter.tsx @@ -1,15 +1,15 @@ import type { FC } from 'react' import type { Tag } from '@/app/components/base/tag-management/constant' -import { useDebounceFn, useMount } from 'ahooks' +import { useMount } from 'ahooks' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { Tag01, Tag03 } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' import Input from '@/app/components/base/input' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { fetchTagList } from '@/service/tag' import { cn } from '@/utils/classnames' @@ -33,18 +33,10 @@ const TagFilter: FC = ({ const setShowTagManagementModal = useTagStore(s => s.setShowTagManagementModal) const [keywords, setKeywords] = useState('') - const [searchKeywords, setSearchKeywords] = useState('') - const { run: handleSearch } = useDebounceFn(() => { - setSearchKeywords(keywords) - }, { wait: 500 }) - const handleKeywordsChange = (value: string) => { - setKeywords(value) - handleSearch() - } const filteredTagList = useMemo(() => { - return tagList.filter(tag => tag.type === type && tag.name.includes(searchKeywords)) - }, [type, tagList, searchKeywords]) + return tagList.filter(tag => tag.type === type && tag.name.includes(keywords)) + }, [type, tagList, keywords]) const currentTag = useMemo(() => { return tagList.find(tag => tag.id === value[0]) @@ -64,61 +56,61 @@ const TagFilter: FC = ({ }) return ( -
    - setOpen(v => !v)} - className="block" - > -
    -
    - -
    -
    - {!value.length && t('tag.placeholder', { ns: 'common' })} - {!!value.length && currentTag?.name} -
    - {value.length > 1 && ( -
    {`+${value.length - 1}`}
    - )} - {!value.length && ( +
    - +
    - )} - {!!value.length && ( -
    { - e.stopPropagation() - onChange([]) - }} - data-testid="tag-filter-clear-button" - > - +
    + {!value.length && t('tag.placeholder', { ns: 'common' })} + {!!value.length && currentTag?.name}
    - )} -
    - - -
    + {value.length > 1 && ( +
    {`+${value.length - 1}`}
    + )} + {!value.length && ( +
    + +
    + )} + + )} + /> + {!!value.length && ( + + )} + +
    handleKeywordsChange(e.target.value)} - onClear={() => handleKeywordsChange('')} + onChange={e => setKeywords(e.target.value)} + onClear={() => setKeywords('')} />
    @@ -155,9 +147,9 @@ const TagFilter: FC = ({
    -
    +
    - + ) } diff --git a/web/app/components/base/toast/context.ts b/web/app/components/base/toast/context.ts index ddd8f91336..07b4e72602 100644 --- a/web/app/components/base/toast/context.ts +++ b/web/app/components/base/toast/context.ts @@ -1,8 +1,15 @@ 'use client' +/** + * @deprecated Use `@/app/components/base/ui/toast` instead. + * This module will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32811 + */ + import type { ReactNode } from 'react' import { createContext, useContext } from 'use-context-selector' +/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */ export type IToastProps = { type?: 'success' | 'error' | 'warning' | 'info' size?: 'md' | 'sm' @@ -19,5 +26,8 @@ type IToastContext = { close: () => void } +/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */ export const ToastContext = createContext({} as IToastContext) + +/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */ export const useToastContext = () => useContext(ToastContext) diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index 897b6039ba..0cb14f3f11 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -1,4 +1,11 @@ 'use client' + +/** + * @deprecated Use `@/app/components/base/ui/toast` instead. + * This component will be removed after migration is complete. + * See: https://github.com/langgenius/dify/issues/32811 + */ + import type { ReactNode } from 'react' import type { IToastProps } from './context' import { noop } from 'es-toolkit/function' @@ -12,6 +19,7 @@ import { ToastContext, useToastContext } from './context' export type ToastHandle = { clear?: VoidFunction } + const Toast = ({ type = 'info', size = 'md', @@ -74,6 +82,7 @@ const Toast = ({ ) } +/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */ export const ToastProvider = ({ children, }: { diff --git a/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx b/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx index c5fb532d98..b6772e5ad0 100644 --- a/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/dropdown-menu/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ import type { ComponentPropsWithoutRef, ReactNode } from 'react' import { fireEvent, render, screen, within } from '@testing-library/react' -import Link from 'next/link' import { describe, expect, it, vi } from 'vitest' +import Link from '@/next/link' import { DropdownMenu, DropdownMenuContent, @@ -14,7 +14,7 @@ import { DropdownMenuTrigger, } from '../index' -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ href, children, diff --git a/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx b/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx new file mode 100644 index 0000000000..b4524a971e --- /dev/null +++ b/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx @@ -0,0 +1,296 @@ +import { render, screen, waitFor } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { + ScrollArea, + ScrollAreaContent, + ScrollAreaCorner, + ScrollAreaRoot, + ScrollAreaScrollbar, + ScrollAreaThumb, + ScrollAreaViewport, +} from '../index' +import styles from '../index.module.css' + +const renderScrollArea = (options: { + rootClassName?: string + viewportClassName?: string + verticalScrollbarClassName?: string + horizontalScrollbarClassName?: string + verticalThumbClassName?: string + horizontalThumbClassName?: string +} = {}) => { + return render( + + + +
    Scrollable content
    +
    +
    + + + + + + +
    , + ) +} + +describe('scroll-area wrapper', () => { + describe('Rendering', () => { + it('should render the compound exports together', async () => { + renderScrollArea() + + await waitFor(() => { + expect(screen.getByTestId('scroll-area-root')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-viewport')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-content')).toHaveTextContent('Scrollable content') + expect(screen.getByTestId('scroll-area-vertical-scrollbar')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-vertical-thumb')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-horizontal-scrollbar')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-horizontal-thumb')).toBeInTheDocument() + }) + }) + + it('should render the convenience wrapper and apply slot props', async () => { + render( + <> +

    Installed apps

    + +
    Scrollable content
    +
    + , + ) + + await waitFor(() => { + const root = screen.getByTestId('scroll-area-wrapper-root') + const viewport = screen.getByRole('region', { name: 'Installed apps' }) + const content = screen.getByText('Scrollable content').parentElement + + expect(root).toBeInTheDocument() + expect(viewport).toHaveClass('custom-viewport-class') + expect(viewport).toHaveAccessibleName('Installed apps') + expect(content).toHaveClass('custom-content-class') + expect(screen.getByText('Scrollable content')).toBeInTheDocument() + }) + }) + }) + + describe('Scrollbar', () => { + it('should apply the default vertical scrollbar classes and orientation data attribute', async () => { + renderScrollArea() + + await waitFor(() => { + const scrollbar = screen.getByTestId('scroll-area-vertical-scrollbar') + const thumb = screen.getByTestId('scroll-area-vertical-thumb') + + expect(scrollbar).toHaveAttribute('data-orientation', 'vertical') + expect(scrollbar).toHaveClass(styles.scrollbar) + expect(scrollbar).toHaveClass( + 'flex', + 'overflow-clip', + 'p-1', + 'touch-none', + 'select-none', + 'opacity-100', + 'transition-opacity', + 'motion-reduce:transition-none', + 'pointer-events-none', + 'data-[hovering]:pointer-events-auto', + 'data-[scrolling]:pointer-events-auto', + 'data-[orientation=vertical]:absolute', + 'data-[orientation=vertical]:inset-y-0', + 'data-[orientation=vertical]:w-3', + 'data-[orientation=vertical]:justify-center', + ) + expect(thumb).toHaveAttribute('data-orientation', 'vertical') + expect(thumb).toHaveClass( + 'shrink-0', + 'rounded-[4px]', + 'bg-state-base-handle', + 'transition-[background-color]', + 'motion-reduce:transition-none', + 'data-[orientation=vertical]:w-1', + ) + }) + }) + + it('should apply horizontal scrollbar and thumb classes when orientation is horizontal', async () => { + renderScrollArea() + + await waitFor(() => { + const scrollbar = screen.getByTestId('scroll-area-horizontal-scrollbar') + const thumb = screen.getByTestId('scroll-area-horizontal-thumb') + + expect(scrollbar).toHaveAttribute('data-orientation', 'horizontal') + expect(scrollbar).toHaveClass(styles.scrollbar) + expect(scrollbar).toHaveClass( + 'flex', + 'overflow-clip', + 'p-1', + 'touch-none', + 'select-none', + 'opacity-100', + 'transition-opacity', + 'motion-reduce:transition-none', + 'pointer-events-none', + 'data-[hovering]:pointer-events-auto', + 'data-[scrolling]:pointer-events-auto', + 'data-[orientation=horizontal]:absolute', + 'data-[orientation=horizontal]:inset-x-0', + 'data-[orientation=horizontal]:h-3', + 'data-[orientation=horizontal]:items-center', + ) + expect(thumb).toHaveAttribute('data-orientation', 'horizontal') + expect(thumb).toHaveClass( + 'shrink-0', + 'rounded-[4px]', + 'bg-state-base-handle', + 'transition-[background-color]', + 'motion-reduce:transition-none', + 'data-[orientation=horizontal]:h-1', + ) + }) + }) + }) + + describe('Props', () => { + it('should forward className to the viewport', async () => { + renderScrollArea({ + viewportClassName: 'custom-viewport-class', + }) + + await waitFor(() => { + expect(screen.getByTestId('scroll-area-viewport')).toHaveClass( + 'size-full', + 'min-h-0', + 'min-w-0', + 'outline-none', + 'focus-visible:ring-1', + 'focus-visible:ring-inset', + 'focus-visible:ring-components-input-border-hover', + 'custom-viewport-class', + ) + }) + }) + + it('should let callers control scrollbar inset spacing via margin-based className overrides', async () => { + renderScrollArea({ + verticalScrollbarClassName: 'data-[orientation=vertical]:my-2 data-[orientation=vertical]:[margin-inline-end:-0.75rem]', + horizontalScrollbarClassName: 'data-[orientation=horizontal]:mx-2 data-[orientation=horizontal]:mb-2', + }) + + await waitFor(() => { + expect(screen.getByTestId('scroll-area-vertical-scrollbar')).toHaveClass( + 'data-[orientation=vertical]:my-2', + 'data-[orientation=vertical]:[margin-inline-end:-0.75rem]', + ) + expect(screen.getByTestId('scroll-area-horizontal-scrollbar')).toHaveClass( + 'data-[orientation=horizontal]:mx-2', + 'data-[orientation=horizontal]:mb-2', + ) + }) + }) + }) + + describe('Corner', () => { + it('should render the corner export when both axes overflow', async () => { + const originalDescriptors = { + clientHeight: Object.getOwnPropertyDescriptor(HTMLDivElement.prototype, 'clientHeight'), + clientWidth: Object.getOwnPropertyDescriptor(HTMLDivElement.prototype, 'clientWidth'), + scrollHeight: Object.getOwnPropertyDescriptor(HTMLDivElement.prototype, 'scrollHeight'), + scrollWidth: Object.getOwnPropertyDescriptor(HTMLDivElement.prototype, 'scrollWidth'), + } + + Object.defineProperties(HTMLDivElement.prototype, { + clientHeight: { + configurable: true, + get() { + return this.getAttribute('data-testid') === 'scroll-area-viewport' ? 80 : 0 + }, + }, + clientWidth: { + configurable: true, + get() { + return this.getAttribute('data-testid') === 'scroll-area-viewport' ? 80 : 0 + }, + }, + scrollHeight: { + configurable: true, + get() { + return this.getAttribute('data-testid') === 'scroll-area-viewport' ? 160 : 0 + }, + }, + scrollWidth: { + configurable: true, + get() { + return this.getAttribute('data-testid') === 'scroll-area-viewport' ? 160 : 0 + }, + }, + }) + + try { + render( + + + +
    Scrollable content
    +
    +
    + + + + + + + +
    , + ) + + await waitFor(() => { + expect(screen.getByTestId('scroll-area-corner')).toBeInTheDocument() + expect(screen.getByTestId('scroll-area-corner')).toHaveClass('bg-transparent') + }) + } + finally { + if (originalDescriptors.clientHeight) { + Object.defineProperty(HTMLDivElement.prototype, 'clientHeight', originalDescriptors.clientHeight) + } + if (originalDescriptors.clientWidth) { + Object.defineProperty(HTMLDivElement.prototype, 'clientWidth', originalDescriptors.clientWidth) + } + if (originalDescriptors.scrollHeight) { + Object.defineProperty(HTMLDivElement.prototype, 'scrollHeight', originalDescriptors.scrollHeight) + } + if (originalDescriptors.scrollWidth) { + Object.defineProperty(HTMLDivElement.prototype, 'scrollWidth', originalDescriptors.scrollWidth) + } + } + }) + }) +}) diff --git a/web/app/components/base/ui/scroll-area/index.module.css b/web/app/components/base/ui/scroll-area/index.module.css new file mode 100644 index 0000000000..a81fd3d3c2 --- /dev/null +++ b/web/app/components/base/ui/scroll-area/index.module.css @@ -0,0 +1,75 @@ +.scrollbar::before, +.scrollbar::after { + content: ''; + position: absolute; + z-index: 1; + border-radius: 9999px; + pointer-events: none; + opacity: 0; + transition: opacity 150ms ease; +} + +.scrollbar[data-orientation='vertical']::before { + left: 50%; + top: 4px; + width: 4px; + height: 12px; + transform: translateX(-50%); + background: linear-gradient(to bottom, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='vertical']::after { + left: 50%; + bottom: 4px; + width: 4px; + height: 12px; + transform: translateX(-50%); + background: linear-gradient(to top, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='horizontal']::before { + top: 50%; + left: 4px; + width: 12px; + height: 4px; + transform: translateY(-50%); + background: linear-gradient(to right, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='horizontal']::after { + top: 50%; + right: 4px; + width: 12px; + height: 4px; + transform: translateY(-50%); + background: linear-gradient(to left, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='vertical']:not([data-overflow-y-start])::before { + opacity: 1; +} + +.scrollbar[data-orientation='vertical']:not([data-overflow-y-end])::after { + opacity: 1; +} + +.scrollbar[data-orientation='horizontal']:not([data-overflow-x-start])::before { + opacity: 1; +} + +.scrollbar[data-orientation='horizontal']:not([data-overflow-x-end])::after { + opacity: 1; +} + +.scrollbar[data-hovering] > [data-orientation], +.scrollbar[data-scrolling] > [data-orientation], +.scrollbar > [data-orientation]:active { + background-color: var(--scroll-area-thumb-bg-active, var(--color-state-base-handle-hover)); +} + +@media (prefers-reduced-motion: reduce) { + .scrollbar::before, + .scrollbar::after { + transition: none; + } +} diff --git a/web/app/components/base/ui/scroll-area/index.stories.tsx b/web/app/components/base/ui/scroll-area/index.stories.tsx new file mode 100644 index 0000000000..4a97610c19 --- /dev/null +++ b/web/app/components/base/ui/scroll-area/index.stories.tsx @@ -0,0 +1,712 @@ +import type { Meta, StoryObj } from '@storybook/nextjs-vite' +import type { ReactNode } from 'react' +import * as React from 'react' +import AppIcon from '@/app/components/base/app-icon' +import { cn } from '@/utils/classnames' +import { + ScrollAreaContent, + ScrollAreaCorner, + ScrollAreaRoot, + ScrollAreaScrollbar, + ScrollAreaThumb, + ScrollAreaViewport, +} from '.' + +const meta = { + title: 'Base/Layout/ScrollArea', + component: ScrollAreaRoot, + parameters: { + layout: 'padded', + docs: { + description: { + component: 'Compound scroll container built on Base UI ScrollArea. These stories focus on panel-style compositions that already exist throughout Dify: dense sidebars, sticky list headers, multi-pane workbenches, horizontal rails, and overlay surfaces. Scrollbar placement should be adjusted by consumer spacing classes such as margin-based overrides instead of right/bottom positioning utilities.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +const panelClassName = 'overflow-hidden rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg shadow-shadow-shadow-5' +const blurPanelClassName = 'overflow-hidden rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-xl shadow-shadow-shadow-7 backdrop-blur-[6px]' +const labelClassName = 'text-text-tertiary system-xs-medium-uppercase tracking-[0.14em]' +const titleClassName = 'text-text-primary system-sm-semibold' +const bodyClassName = 'text-text-secondary system-sm-regular' +const insetScrollAreaClassName = 'h-full p-1' +const insetViewportClassName = 'rounded-[20px] bg-components-panel-bg' +const insetScrollbarClassName = 'data-[orientation=vertical]:my-1 data-[orientation=vertical]:[margin-inline-end:0.25rem] data-[orientation=horizontal]:mx-1 data-[orientation=horizontal]:mb-1' +const storyButtonClassName = 'flex w-full items-center justify-between gap-3 rounded-xl border border-divider-subtle bg-components-panel-bg-alt px-3 py-2.5 text-left text-text-secondary transition-colors hover:bg-state-base-hover focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-inset focus-visible:ring-components-input-border-hover motion-reduce:transition-none' +const sidebarScrollAreaClassName = 'h-full' +const sidebarViewportClassName = 'overscroll-contain' +const sidebarContentClassName = 'space-y-0.5' +const sidebarScrollbarClassName = 'data-[orientation=vertical]:my-2 data-[orientation=vertical]:[margin-inline-end:-0.75rem]' +const appNavButtonClassName = 'group flex h-8 w-full items-center justify-between gap-3 rounded-lg px-2 text-left transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-inset focus-visible:ring-components-input-border-hover motion-reduce:transition-none' +const appNavMetaClassName = 'shrink-0 rounded-md border border-divider-subtle bg-components-panel-bg-alt px-1.5 py-0.5 text-text-quaternary system-2xs-medium-uppercase tracking-[0.08em]' + +const releaseRows = [ + { title: 'Agent refactor', meta: 'Updated 2 hours ago', status: 'Ready' }, + { title: 'Retriever tuning', meta: 'Updated yesterday', status: 'Review' }, + { title: 'Workflow replay', meta: 'Updated 3 days ago', status: 'Draft' }, + { title: 'Sandbox policy', meta: 'Updated this week', status: 'Ready' }, + { title: 'SSE diagnostics', meta: 'Updated last week', status: 'Blocked' }, + { title: 'Model routing', meta: 'Updated 9 days ago', status: 'Review' }, + { title: 'Chunk overlap', meta: 'Updated 11 days ago', status: 'Draft' }, + { title: 'Vector warmup', meta: 'Updated 2 weeks ago', status: 'Ready' }, +] as const + +const queueRows = [ + { id: 'PLG-142', title: 'Plugin catalog sync', note: 'Waiting for moderation result' }, + { id: 'OPS-088', title: 'Billing alert fallback', note: 'Last retry finished 12 minutes ago' }, + { id: 'RAG-511', title: 'Embedding migration', note: '16 datasets still pending' }, + { id: 'AGT-204', title: 'Multi-agent tracing', note: 'QA is verifying edge cases' }, + { id: 'UI-390', title: 'Prompt editor polish', note: 'Needs token density pass' }, + { id: 'WEB-072', title: 'Marketplace empty state', note: 'Waiting for design review' }, +] as const + +const horizontalCards = [ + { title: 'Claude Opus', detail: 'Reasoning-heavy preset' }, + { title: 'GPT-5.4', detail: 'Balanced orchestration lane' }, + { title: 'Gemini 2.5', detail: 'Multimodal fallback' }, + { title: 'Qwen Max', detail: 'Regional deployment' }, + { title: 'DeepSeek R1', detail: 'High-throughput analysis' }, + { title: 'Llama 4', detail: 'Cost-sensitive routing' }, +] as const + +const activityRows = Array.from({ length: 14 }, (_, index) => ({ + title: `Workspace activity ${index + 1}`, + body: 'A short line of copy to mimic dense operational feeds in settings and debug panels.', +})) + +const scrollbarShowcaseRows = Array.from({ length: 18 }, (_, index) => ({ + title: `Scroll checkpoint ${index + 1}`, + body: 'Dedicated story content so the scrollbar can be inspected without sticky headers, masks, or clipped shells.', +})) + +const horizontalShowcaseCards = Array.from({ length: 8 }, (_, index) => ({ + title: `Lane ${index + 1}`, + body: 'Horizontal scrollbar reference without edge hints.', +})) + +const webAppsRows = [ + { id: 'invoice-copilot', name: 'Invoice Copilot', meta: 'Pinned', icon: '🧾', iconBackground: '#FFEAD5', selected: true, pinned: true }, + { id: 'rag-ops', name: 'RAG Ops Console', meta: 'Ops', icon: '🛰️', iconBackground: '#E0F2FE', selected: false, pinned: true }, + { id: 'knowledge-studio', name: 'Knowledge Studio', meta: 'Docs', icon: '📚', iconBackground: '#FEF3C7', selected: false, pinned: true }, + { id: 'workflow-studio', name: 'Workflow Studio', meta: 'Build', icon: '🧩', iconBackground: '#E0E7FF', selected: false, pinned: true }, + { id: 'growth-briefs', name: 'Growth Briefs', meta: 'Brief', icon: '📣', iconBackground: '#FCE7F3', selected: false, pinned: true }, + { id: 'agent-playground', name: 'Agent Playground', meta: 'Lab', icon: '🧪', iconBackground: '#DCFCE7', selected: false, pinned: false }, + { id: 'sales-briefing', name: 'Sales Briefing', meta: 'Team', icon: '📈', iconBackground: '#FCE7F3', selected: false, pinned: false }, + { id: 'support-triage', name: 'Support Triage', meta: 'Queue', icon: '🎧', iconBackground: '#EDE9FE', selected: false, pinned: false }, + { id: 'legal-review', name: 'Legal Review', meta: 'Beta', icon: '⚖️', iconBackground: '#FDE68A', selected: false, pinned: false }, + { id: 'release-watcher', name: 'Release Watcher', meta: 'Feed', icon: '🚀', iconBackground: '#DBEAFE', selected: false, pinned: false }, + { id: 'research-hub', name: 'Research Hub', meta: 'Notes', icon: '🔎', iconBackground: '#E0F2FE', selected: false, pinned: false }, + { id: 'field-enablement', name: 'Field Enablement', meta: 'Team', icon: '🧭', iconBackground: '#DCFCE7', selected: false, pinned: false }, + { id: 'brand-monitor', name: 'Brand Monitor', meta: 'Watch', icon: '🪄', iconBackground: '#F3E8FF', selected: false, pinned: false }, + { id: 'finance-ops', name: 'Finance Ops Desk', meta: 'Ops', icon: '💳', iconBackground: '#FEF3C7', selected: false, pinned: false }, + { id: 'security-radar', name: 'Security Radar', meta: 'Risk', icon: '🛡️', iconBackground: '#FEE2E2', selected: false, pinned: false }, + { id: 'partner-portal', name: 'Partner Portal', meta: 'Ext', icon: '🤝', iconBackground: '#DBEAFE', selected: false, pinned: false }, + { id: 'qa-replays', name: 'QA Replays', meta: 'Debug', icon: '🎞️', iconBackground: '#EDE9FE', selected: false, pinned: false }, + { id: 'roadmap-notes', name: 'Roadmap Notes', meta: 'Plan', icon: '🗺️', iconBackground: '#FFEAD5', selected: false, pinned: false }, +] as const + +const StoryCard = ({ + eyebrow, + title, + description, + className, + children, +}: { + eyebrow: string + title: string + description: string + className?: string + children: ReactNode +}) => ( +
    +
    +
    {eyebrow}
    +

    {title}

    +

    {description}

    +
    + {children} +
    +) + +const VerticalPanelPane = () => ( +
    + + + +
    +
    Release board
    +
    Weekly checkpoints
    +

    A simple vertical panel with the default scrollbar skin and no business-specific overrides.

    +
    + {releaseRows.map(item => ( +
    +
    +
    +

    {item.title}

    +

    {item.meta}

    +
    + + {item.status} + +
    +
    + ))} +
    +
    + + + +
    +
    +) + +const StickyListPane = () => ( +
    + + + +
    +
    Sticky header
    +
    +
    +
    Operational queue
    +

    The scrollbar is still the shared base/ui primitive, while the pane adds sticky structure and a viewport mask.

    +
    + + 24 items + +
    +
    +
    + {queueRows.map(item => ( +
    +
    +
    +
    {item.title}
    +
    {item.note}
    +
    + {item.id} +
    +
    + ))} +
    +
    +
    + + + +
    +
    +) + +const WorkbenchPane = ({ + title, + eyebrow, + children, + className, +}: { + title: string + eyebrow: string + children: ReactNode + className?: string +}) => ( +
    + + + +
    +
    {eyebrow}
    +
    {title}
    +
    + {children} +
    +
    + + + +
    +
    +) + +const HorizontalRailPane = () => ( +
    + + + +
    +
    Horizontal rail
    +
    Model lanes
    +

    This pane keeps the default track behavior and only changes the surface layout around it.

    +
    +
    + {horizontalCards.map(card => ( +
    +
    + + + +
    {card.title}
    +
    {card.detail}
    +
    +
    Drag cards into orchestration groups.
    +
    + ))} +
    +
    +
    + + + +
    +
    +) + +const ScrollbarStatePane = ({ + eyebrow, + title, + description, + initialPosition, +}: { + eyebrow: string + title: string + description: string + initialPosition: 'top' | 'middle' | 'bottom' +}) => { + const viewportId = React.useId() + + React.useEffect(() => { + let frameA = 0 + let frameB = 0 + + const syncScrollPosition = () => { + const viewport = document.getElementById(viewportId) + + if (!(viewport instanceof HTMLDivElement)) + return + + const maxScrollTop = Math.max(0, viewport.scrollHeight - viewport.clientHeight) + + if (initialPosition === 'top') + viewport.scrollTop = 0 + + if (initialPosition === 'middle') + viewport.scrollTop = maxScrollTop / 2 + + if (initialPosition === 'bottom') + viewport.scrollTop = maxScrollTop + } + + frameA = requestAnimationFrame(() => { + frameB = requestAnimationFrame(syncScrollPosition) + }) + + return () => { + cancelAnimationFrame(frameA) + cancelAnimationFrame(frameB) + } + }, [initialPosition, viewportId]) + + return ( +
    +
    +
    {eyebrow}
    +
    {title}
    +

    {description}

    +
    +
    + + + + {scrollbarShowcaseRows.map(item => ( +
    +
    {item.title}
    +
    {item.body}
    +
    + ))} +
    +
    + + + +
    +
    +
    + ) +} + +const HorizontalScrollbarShowcasePane = () => ( +
    +
    +
    Horizontal
    +
    Horizontal track reference
    +

    Current design delivery defines the horizontal scrollbar body, but not a horizontal edge hint.

    +
    +
    + + + +
    +
    Horizontal scrollbar
    +
    A clean horizontal pane to inspect thickness, padding, and thumb behavior without extra masks.
    +
    +
    + {horizontalShowcaseCards.map(card => ( +
    +
    {card.title}
    +
    {card.body}
    +
    + ))} +
    +
    +
    + + + +
    +
    +
    +) + +const OverlayPane = () => ( +
    +
    + + + +
    +
    Overlay palette
    +
    Quick actions
    +
    + {activityRows.map(item => ( +
    +
    + + + +
    +
    {item.title}
    +
    {item.body}
    +
    +
    +
    + ))} +
    +
    + + + +
    +
    +
    +) + +const CornerPane = () => ( +
    + + + +
    +
    +
    Corner surface
    +
    Bi-directional inspector canvas
    +

    Both axes overflow here so the corner becomes visible as a deliberate seam between the two tracks.

    +
    + + Always visible + +
    +
    + {Array.from({ length: 12 }, (_, index) => ( +
    +
    + Cell + {' '} + {index + 1} +
    +

    + Wide-and-tall content to force both scrollbars and show the corner treatment clearly. +

    +
    + ))} +
    +
    +
    + + + + + + + +
    +
    +) + +const ExploreSidebarWebAppsPane = () => { + const pinnedAppsCount = webAppsRows.filter(item => item.pinned).length + + return ( +
    +
    +
    +
    +
    + +
    +
    + Explore +
    +
    +
    + +
    +
    +

    + Web Apps +

    + + {webAppsRows.length} + +
    + +
    + + + + {webAppsRows.map((item, index) => ( +
    + + {index === pinnedAppsCount - 1 && index !== webAppsRows.length - 1 && ( +
    + )} +
    + ))} + + + + + + +
    +
    +
    +
    + ) +} + +export const VerticalPanels: Story = { + render: () => ( + +
    + + +
    +
    + ), +} + +export const ThreePaneWorkbench: Story = { + render: () => ( + +
    + +
    + {releaseRows.map(item => ( + + ))} +
    +
    + +
    + {Array.from({ length: 7 }, (_, index) => ( +
    +
    +
    + Section + {' '} + {index + 1} +
    + + Active + +
    +

    + This pane is intentionally long so the default vertical scrollbar sits over a larger editorial surface. +

    +
    + ))} +
    +
    + +
    + {queueRows.map(item => ( +
    +
    {item.id}
    +
    {item.title}
    +
    {item.note}
    +
    + ))} +
    +
    +
    +
    + ), +} + +export const HorizontalAndOverlay: Story = { + render: () => ( +
    + + + + + + +
    + ), +} + +export const CornerSurface: Story = { + render: () => ( + +
    + +
    +
    + ), +} + +export const ExploreSidebarWebApps: Story = { + render: () => ( + +
    + +
    +
    + ), +} + +export const PrimitiveComposition: Story = { + render: () => ( + +
    + + + + {Array.from({ length: 8 }, (_, index) => ( +
    + Primitive row + {' '} + {index + 1} +
    + ))} +
    +
    + + + + + + + +
    +
    +
    + ), +} + +export const ScrollbarDelivery: Story = { + render: () => ( + +
    + + + + +
    +
    + ), +} diff --git a/web/app/components/base/ui/scroll-area/index.tsx b/web/app/components/base/ui/scroll-area/index.tsx new file mode 100644 index 0000000000..b0f85f78d4 --- /dev/null +++ b/web/app/components/base/ui/scroll-area/index.tsx @@ -0,0 +1,132 @@ +'use client' + +import { ScrollArea as BaseScrollArea } from '@base-ui/react/scroll-area' +import * as React from 'react' +import { cn } from '@/utils/classnames' +import styles from './index.module.css' + +export const ScrollAreaRoot = BaseScrollArea.Root +export type ScrollAreaRootProps = React.ComponentPropsWithRef + +export const ScrollAreaContent = BaseScrollArea.Content +export type ScrollAreaContentProps = React.ComponentPropsWithRef + +export type ScrollAreaSlotClassNames = { + viewport?: string + content?: string + scrollbar?: string +} + +export type ScrollAreaProps = Omit & { + children: React.ReactNode + orientation?: 'vertical' | 'horizontal' + slotClassNames?: ScrollAreaSlotClassNames + label?: string + labelledBy?: string +} + +export const scrollAreaScrollbarClassName = cn( + styles.scrollbar, + 'flex touch-none select-none overflow-clip p-1 opacity-100 transition-opacity motion-reduce:transition-none', + 'pointer-events-none data-[hovering]:pointer-events-auto', + 'data-[scrolling]:pointer-events-auto', + 'data-[orientation=vertical]:absolute data-[orientation=vertical]:inset-y-0 data-[orientation=vertical]:w-3 data-[orientation=vertical]:justify-center', + 'data-[orientation=horizontal]:absolute data-[orientation=horizontal]:inset-x-0 data-[orientation=horizontal]:h-3 data-[orientation=horizontal]:items-center', +) + +export const scrollAreaThumbClassName = cn( + 'shrink-0 rounded-[4px] bg-state-base-handle transition-[background-color] motion-reduce:transition-none', + 'data-[orientation=vertical]:w-1', + 'data-[orientation=horizontal]:h-1', +) + +export const scrollAreaViewportClassName = cn( + 'size-full min-h-0 min-w-0 outline-none', + 'focus-visible:ring-1 focus-visible:ring-inset focus-visible:ring-components-input-border-hover', +) + +export const scrollAreaCornerClassName = 'bg-transparent' + +export type ScrollAreaViewportProps = React.ComponentPropsWithRef + +export function ScrollAreaViewport({ + className, + ...props +}: ScrollAreaViewportProps) { + return ( + + ) +} + +export type ScrollAreaScrollbarProps = React.ComponentPropsWithRef + +export function ScrollAreaScrollbar({ + className, + ...props +}: ScrollAreaScrollbarProps) { + return ( + + ) +} + +export type ScrollAreaThumbProps = React.ComponentPropsWithRef + +export function ScrollAreaThumb({ + className, + ...props +}: ScrollAreaThumbProps) { + return ( + + ) +} + +export type ScrollAreaCornerProps = React.ComponentPropsWithRef + +export function ScrollAreaCorner({ + className, + ...props +}: ScrollAreaCornerProps) { + return ( + + ) +} + +export function ScrollArea({ + children, + className, + orientation = 'vertical', + slotClassNames, + label, + labelledBy, + ...props +}: ScrollAreaProps) { + return ( + + + + {children} + + + + + + + ) +} diff --git a/web/app/components/base/ui/toast/__tests__/index.spec.tsx b/web/app/components/base/ui/toast/__tests__/index.spec.tsx index 75364117c3..db6d86719a 100644 --- a/web/app/components/base/ui/toast/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/toast/__tests__/index.spec.tsx @@ -7,27 +7,25 @@ describe('base/ui/toast', () => { vi.clearAllMocks() vi.useFakeTimers({ shouldAdvanceTime: true }) act(() => { - toast.close() + toast.dismiss() }) }) afterEach(() => { act(() => { - toast.close() + toast.dismiss() vi.runOnlyPendingTimers() }) vi.useRealTimers() }) // Core host and manager integration. - it('should render a toast when add is called', async () => { + it('should render a success toast when called through the typed shortcut', async () => { render() act(() => { - toast.add({ - title: 'Saved', + toast.success('Saved', { description: 'Your changes are available now.', - type: 'success', }) }) @@ -47,20 +45,14 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'First toast', - }) + toast('First toast') }) expect(await screen.findByText('First toast')).toBeInTheDocument() act(() => { - toast.add({ - title: 'Second toast', - }) - toast.add({ - title: 'Third toast', - }) + toast('Second toast') + toast('Third toast') }) expect(await screen.findByText('Third toast')).toBeInTheDocument() @@ -74,13 +66,25 @@ describe('base/ui/toast', () => { }) }) + // Neutral calls should map directly to a toast with only a title. + it('should render a neutral toast when called directly', async () => { + render() + + act(() => { + toast('Neutral toast') + }) + + expect(await screen.findByText('Neutral toast')).toBeInTheDocument() + expect(document.body.querySelector('[aria-hidden="true"].i-ri-information-2-fill')).not.toBeInTheDocument() + }) + // Base UI limit should cap the visible stack and mark overflow toasts as limited. it('should mark overflow toasts as limited when the stack exceeds the configured limit', async () => { render() act(() => { - toast.add({ title: 'First toast' }) - toast.add({ title: 'Second toast' }) + toast('First toast') + toast('Second toast') }) expect(await screen.findByText('Second toast')).toBeInTheDocument() @@ -88,13 +92,12 @@ describe('base/ui/toast', () => { }) // Closing should work through the public manager API. - it('should close a toast when close(id) is called', async () => { + it('should dismiss a toast when dismiss(id) is called', async () => { render() let toastId = '' act(() => { - toastId = toast.add({ - title: 'Closable', + toastId = toast('Closable', { description: 'This toast can be removed.', }) }) @@ -102,7 +105,7 @@ describe('base/ui/toast', () => { expect(await screen.findByText('Closable')).toBeInTheDocument() act(() => { - toast.close(toastId) + toast.dismiss(toastId) }) await waitFor(() => { @@ -117,8 +120,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Dismiss me', + toast('Dismiss me', { description: 'Manual dismissal path.', onClose, }) @@ -143,9 +145,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Default timeout', - }) + toast('Default timeout') }) expect(await screen.findByText('Default timeout')).toBeInTheDocument() @@ -170,9 +170,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Configured timeout', - }) + toast('Configured timeout') }) expect(await screen.findByText('Configured timeout')).toBeInTheDocument() @@ -197,8 +195,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Custom timeout', + toast('Custom timeout', { timeout: 1000, }) }) @@ -214,8 +211,7 @@ describe('base/ui/toast', () => { }) act(() => { - toast.add({ - title: 'Persistent', + toast('Persistent', { timeout: 0, }) }) @@ -235,10 +231,8 @@ describe('base/ui/toast', () => { let toastId = '' act(() => { - toastId = toast.add({ - title: 'Loading', + toastId = toast.info('Loading', { description: 'Preparing your data…', - type: 'info', }) }) @@ -264,8 +258,7 @@ describe('base/ui/toast', () => { render() act(() => { - toast.add({ - title: 'Action toast', + toast('Action toast', { actionProps: { children: 'Undo', onClick: onAction, diff --git a/web/app/components/base/ui/toast/index.stories.tsx b/web/app/components/base/ui/toast/index.stories.tsx index 045ca96823..a0dd806d19 100644 --- a/web/app/components/base/ui/toast/index.stories.tsx +++ b/web/app/components/base/ui/toast/index.stories.tsx @@ -57,9 +57,8 @@ const VariantExamples = () => { }, } as const - toast.add({ - type, - ...copy[type], + toast[type](copy[type].title, { + description: copy[type].description, }) } @@ -103,14 +102,16 @@ const StackExamples = () => { title: 'Ready to publish', description: 'The newest toast stays frontmost while older items tuck behind it.', }, - ].forEach(item => toast.add(item)) + ].forEach((item) => { + toast[item.type](item.title, { + description: item.description, + }) + }) } const createBurst = () => { Array.from({ length: 5 }).forEach((_, index) => { - toast.add({ - type: index % 2 === 0 ? 'info' : 'success', - title: `Background task ${index + 1}`, + toast[index % 2 === 0 ? 'info' : 'success'](`Background task ${index + 1}`, { description: 'Use this to inspect how the stack behaves near the host limit.', }) }) @@ -191,16 +192,12 @@ const PromiseExamples = () => { const ActionExamples = () => { const createActionToast = () => { - toast.add({ - type: 'warning', - title: 'Project archived', + toast.warning('Project archived', { description: 'You can restore it from workspace settings for the next 30 days.', actionProps: { children: 'Undo', onClick: () => { - toast.add({ - type: 'success', - title: 'Project restored', + toast.success('Project restored', { description: 'The workspace is active again.', }) }, @@ -209,17 +206,12 @@ const ActionExamples = () => { } const createLongCopyToast = () => { - toast.add({ - type: 'info', - title: 'Knowledge ingestion in progress', + toast.info('Knowledge ingestion in progress', { description: 'This longer example helps validate line wrapping, close button alignment, and action button placement when the content spans multiple rows.', actionProps: { children: 'View details', onClick: () => { - toast.add({ - type: 'info', - title: 'Job details opened', - }) + toast.info('Job details opened') }, }, }) @@ -243,9 +235,7 @@ const ActionExamples = () => { const UpdateExamples = () => { const createUpdatableToast = () => { - const toastId = toast.add({ - type: 'info', - title: 'Import started', + const toastId = toast.info('Import started', { description: 'Preparing assets and metadata for processing.', timeout: 0, }) @@ -261,7 +251,7 @@ const UpdateExamples = () => { } const clearAll = () => { - toast.close() + toast.dismiss() } return ( diff --git a/web/app/components/base/ui/toast/index.tsx b/web/app/components/base/ui/toast/index.tsx index d91648e44a..a3f4e13727 100644 --- a/web/app/components/base/ui/toast/index.tsx +++ b/web/app/components/base/ui/toast/index.tsx @@ -5,6 +5,7 @@ import type { ToastManagerUpdateOptions, ToastObject, } from '@base-ui/react/toast' +import type { ReactNode } from 'react' import { Toast as BaseToast } from '@base-ui/react/toast' import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' @@ -44,6 +45,9 @@ export type ToastUpdateOptions = Omit, 'dat type?: ToastType } +export type ToastOptions = Omit +export type TypedToastOptions = Omit + type ToastPromiseResultOption = string | ToastUpdateOptions | ((value: Value) => string | ToastUpdateOptions) export type ToastPromiseOptions = { @@ -57,6 +61,21 @@ export type ToastHostProps = { limit?: number } +type ToastDismiss = (toastId?: string) => void +type ToastCall = (title: ReactNode, options?: ToastOptions) => string +type TypedToastCall = (title: ReactNode, options?: TypedToastOptions) => string + +export type ToastApi = { + (title: ReactNode, options?: ToastOptions): string + success: TypedToastCall + error: TypedToastCall + warning: TypedToastCall + info: TypedToastCall + dismiss: ToastDismiss + update: (toastId: string, options: ToastUpdateOptions) => void + promise: (promiseValue: Promise, options: ToastPromiseOptions) => Promise +} + const toastManager = BaseToast.createToastManager() function isToastType(type: string): type is ToastType { @@ -67,21 +86,48 @@ function getToastType(type?: string): ToastType | undefined { return type && isToastType(type) ? type : undefined } -export const toast = { - add(options: ToastAddOptions) { - return toastManager.add(options) - }, - close(toastId?: string) { - toastManager.close(toastId) - }, - update(toastId: string, options: ToastUpdateOptions) { - toastManager.update(toastId, options) - }, - promise(promiseValue: Promise, options: ToastPromiseOptions) { - return toastManager.promise(promiseValue, options) - }, +function addToast(options: ToastAddOptions) { + return toastManager.add(options) } +const showToast: ToastCall = (title, options) => addToast({ + ...options, + title, +}) + +const dismissToast: ToastDismiss = (toastId) => { + toastManager.close(toastId) +} + +function createTypedToast(type: ToastType): TypedToastCall { + return (title, options) => addToast({ + ...options, + title, + type, + }) +} + +function updateToast(toastId: string, options: ToastUpdateOptions) { + toastManager.update(toastId, options) +} + +function promiseToast(promiseValue: Promise, options: ToastPromiseOptions) { + return toastManager.promise(promiseValue, options) +} + +export const toast: ToastApi = Object.assign( + showToast, + { + success: createTypedToast('success'), + error: createTypedToast('error'), + warning: createTypedToast('warning'), + info: createTypedToast('info'), + dismiss: dismissToast, + update: updateToast, + promise: promiseToast, + }, +) + function ToastIcon({ type }: { type?: ToastType }) { return type ?
    diff --git a/web/app/components/base/zendesk/index.tsx b/web/app/components/base/zendesk/index.tsx index 4879725c85..20f4f84baf 100644 --- a/web/app/components/base/zendesk/index.tsx +++ b/web/app/components/base/zendesk/index.tsx @@ -1,7 +1,7 @@ -import { headers } from 'next/headers' -import Script from 'next/script' import { memo } from 'react' import { IS_CE_EDITION, IS_PROD, ZENDESK_WIDGET_KEY } from '@/config' +import { headers } from '@/next/headers' +import Script from '@/next/script' const Zendesk = async () => { if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY) diff --git a/web/app/components/billing/partner-stack/__tests__/use-ps-info.spec.tsx b/web/app/components/billing/partner-stack/__tests__/use-ps-info.spec.tsx index ec79d18d29..2ea5db840f 100644 --- a/web/app/components/billing/partner-stack/__tests__/use-ps-info.spec.tsx +++ b/web/app/components/billing/partner-stack/__tests__/use-ps-info.spec.tsx @@ -48,7 +48,7 @@ vi.mock('js-cookie', () => { remove, } }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => ({ get: (key: string) => searchParamsValues[key] ?? null, }), diff --git a/web/app/components/billing/partner-stack/use-ps-info.ts b/web/app/components/billing/partner-stack/use-ps-info.ts index 51d693f358..7c45d7ef87 100644 --- a/web/app/components/billing/partner-stack/use-ps-info.ts +++ b/web/app/components/billing/partner-stack/use-ps-info.ts @@ -1,8 +1,8 @@ import { useBoolean } from 'ahooks' import Cookies from 'js-cookie' -import { useSearchParams } from 'next/navigation' import { useCallback } from 'react' import { PARTNER_STACK_CONFIG } from '@/config' +import { useSearchParams } from '@/next/navigation' import { useBindPartnerStackInfo } from '@/service/use-billing' const usePSInfo = () => { diff --git a/web/app/components/billing/plan/__tests__/index.spec.tsx b/web/app/components/billing/plan/__tests__/index.spec.tsx index 79597b4b22..bed7ebd9fb 100644 --- a/web/app/components/billing/plan/__tests__/index.spec.tsx +++ b/web/app/components/billing/plan/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ let currentPath = '/billing' const push = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push }), usePathname: () => currentPath, })) diff --git a/web/app/components/billing/plan/index.tsx b/web/app/components/billing/plan/index.tsx index 2f953c3a8e..b420110a4d 100644 --- a/web/app/components/billing/plan/index.tsx +++ b/web/app/components/billing/plan/index.tsx @@ -7,7 +7,6 @@ import { RiGroupLine, } from '@remixicon/react' import { useUnmountedRef } from 'ahooks' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useEffect } from 'react' import { useTranslation } from 'react-i18next' @@ -19,6 +18,7 @@ import VerifyStateModal from '@/app/education-apply/verify-state-modal' import { useAppContext } from '@/context/app-context' import { useModalContextSelector } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' +import { usePathname, useRouter } from '@/next/navigation' import { useEducationVerify } from '@/service/use-education' import { getDaysUntilEndOfMonth } from '@/utils/time' import { Loading } from '../../base/icons/src/public/thought' diff --git a/web/app/components/billing/pricing/__tests__/footer.spec.tsx b/web/app/components/billing/pricing/__tests__/footer.spec.tsx index 762d0ad211..9a9215c177 100644 --- a/web/app/components/billing/pricing/__tests__/footer.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/footer.spec.tsx @@ -3,7 +3,7 @@ import * as React from 'react' import Footer from '../footer' import { CategoryEnum } from '../types' -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => ( {children} diff --git a/web/app/components/billing/pricing/__tests__/header.spec.tsx b/web/app/components/billing/pricing/__tests__/header.spec.tsx index 0aadc3b0ce..cb8991ff42 100644 --- a/web/app/components/billing/pricing/__tests__/header.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/header.spec.tsx @@ -1,12 +1,14 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' -import { Dialog } from '@/app/components/base/ui/dialog' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import Header from '../header' function renderHeader(onClose: () => void) { return render( -
    + +
    +
    , ) } @@ -24,7 +26,7 @@ describe('Header', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) @@ -33,7 +35,7 @@ describe('Header', () => { const handleClose = vi.fn() renderHeader(handleClose) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) expect(handleClose).toHaveBeenCalledTimes(1) }) @@ -41,11 +43,11 @@ describe('Header', () => { describe('Edge Cases', () => { it('should render structural elements with translation keys', () => { - const { container } = renderHeader(vi.fn()) + renderHeader(vi.fn()) - expect(container.querySelector('span')).toBeInTheDocument() - expect(container.querySelector('p')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) }) diff --git a/web/app/components/billing/pricing/__tests__/index.spec.tsx b/web/app/components/billing/pricing/__tests__/index.spec.tsx index 1be2234cf9..a8d0a4329e 100644 --- a/web/app/components/billing/pricing/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/index.spec.tsx @@ -19,7 +19,7 @@ vi.mock('../plans/self-hosted-plan-item/list', () => ({ ), })) -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (
    {children} @@ -68,6 +68,7 @@ describe('Pricing', () => { it('should render pricing header and localized footer link', () => { render() + expect(screen.getByRole('dialog', { name: 'billing.plansCommon.title.plans' })).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 'https://dify.ai/en/pricing#plans-and-features') }) diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 6a213eca00..1422ec1cb1 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -1,7 +1,7 @@ import type { Category } from './types' -import Link from 'next/link' import * as React from 'react' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' import { cn } from '@/utils/classnames' import { CategoryEnum } from './types' @@ -28,8 +28,9 @@ const Footer = ({ {t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/header.tsx b/web/app/components/billing/pricing/header.tsx index d0ffe100db..5ab1895667 100644 --- a/web/app/components/billing/pricing/header.tsx +++ b/web/app/components/billing/pricing/header.tsx @@ -1,5 +1,6 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' +import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog' import { cn } from '@/utils/classnames' import Button from '../../base/button' import DifyLogo from '../../base/logo/dify-logo' @@ -18,24 +19,25 @@ const Header = ({
    -
    + - {t('plansCommon.title.plans', { ns: 'billing' })} - +
    -

    + {t('plansCommon.title.description', { ns: 'billing' })} -

    +
    -
    {t(`${i18nPrefix}.description`, { ns: 'billing' })}
    +
    {t(`${i18nPrefix}.description`, { ns: 'billing' })}
    {/* Price */}
    {isFreePlan && ( - {t('plansCommon.free', { ns: 'billing' })} + {t('plansCommon.free', { ns: 'billing' })} )} {!isFreePlan && ( <> {isYear && ( - + $ {planInfo.price * 12} )} - + $ {isYear ? planInfo.price * 10 : planInfo.price} - + {t('plansCommon.priceTip', { ns: 'billing' })} {t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/plans/self-hosted-plan-item/__tests__/index.spec.tsx b/web/app/components/billing/pricing/plans/self-hosted-plan-item/__tests__/index.spec.tsx index 9507cdef3c..103b188046 100644 --- a/web/app/components/billing/pricing/plans/self-hosted-plan-item/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/plans/self-hosted-plan-item/__tests__/index.spec.tsx @@ -1,8 +1,8 @@ import type { Mock } from 'vitest' import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { useAppContext } from '@/context/app-context' -import Toast from '../../../../../base/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../../config' import { SelfHostedPlan } from '../../../../type' import SelfHostedPlanItem from '../index' @@ -16,12 +16,6 @@ vi.mock('../list', () => ({ ), })) -vi.mock('../../../../../base/toast', () => ({ - default: { - notify: vi.fn(), - }, -})) - vi.mock('@/context/app-context', () => ({ useAppContext: vi.fn(), })) @@ -35,11 +29,19 @@ vi.mock('../../../assets', () => ({ })) const mockUseAppContext = useAppContext as Mock -const mockToastNotify = Toast.notify as Mock let assignedHref = '' const originalLocation = window.location +const renderWithToastHost = (ui: React.ReactNode) => { + return render( + <> + + {ui} + , + ) +} + beforeAll(() => { Object.defineProperty(window, 'location', { configurable: true, @@ -56,6 +58,7 @@ beforeAll(() => { beforeEach(() => { vi.clearAllMocks() + toast.dismiss() mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true }) assignedHref = '' }) @@ -90,13 +93,10 @@ describe('SelfHostedPlanItem', () => { it('should show toast when non-manager tries to proceed', () => { mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false }) - render() + renderWithToastHost() fireEvent.click(screen.getByRole('button', { name: /billing\.plans\.premium\.btnText/ })) - expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'error', - message: 'billing.buyPermissionDeniedTip', - })) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) it('should redirect to community url when community plan button clicked', () => { diff --git a/web/app/components/billing/pricing/plans/self-hosted-plan-item/index.tsx b/web/app/components/billing/pricing/plans/self-hosted-plan-item/index.tsx index eaee5082ff..e377dcb0d8 100644 --- a/web/app/components/billing/pricing/plans/self-hosted-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/self-hosted-plan-item/index.tsx @@ -4,9 +4,9 @@ import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { Azure, GoogleCloud } from '@/app/components/base/icons/src/public/billing' +import { toast } from '@/app/components/base/ui/toast' import { useAppContext } from '@/context/app-context' import { cn } from '@/utils/classnames' -import Toast from '../../../../base/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../config' import { SelfHostedPlan } from '../../../type' import { Community, Enterprise, EnterpriseNoise, Premium, PremiumNoise } from '../../assets' @@ -56,11 +56,7 @@ const SelfHostedPlanItem: FC = ({ const handleGetPayUrl = useCallback(() => { // Only workspace manager can buy plan if (!isCurrentWorkspaceManager) { - Toast.notify({ - type: 'error', - message: t('buyPermissionDeniedTip', { ns: 'billing' }), - className: 'z-[1001]', - }) + toast.error(t('buyPermissionDeniedTip', { ns: 'billing' })) return } if (isFreePlan) { @@ -82,18 +78,18 @@ const SelfHostedPlanItem: FC = ({ {/* Noise Effect */} {STYLE_MAP[plan].noise}
    -
    +
    {STYLE_MAP[plan].icon}
    {t(`${i18nPrefix}.name`, { ns: 'billing' })}
    -
    {t(`${i18nPrefix}.description`, { ns: 'billing' })}
    +
    {t(`${i18nPrefix}.description`, { ns: 'billing' })}
    {/* Price */}
    -
    {t(`${i18nPrefix}.price`, { ns: 'billing' })}
    +
    {t(`${i18nPrefix}.price`, { ns: 'billing' })}
    {!isFreePlan && ( - + {t(`${i18nPrefix}.priceTip`, { ns: 'billing' })} )} @@ -114,7 +110,7 @@ const SelfHostedPlanItem: FC = ({
    - + {t('plans.premium.comingSoon', { ns: 'billing' })}
    diff --git a/web/app/components/browser-initializer.tsx b/web/app/components/browser-initializer.tsx deleted file mode 100644 index c2194ca8d4..0000000000 --- a/web/app/components/browser-initializer.tsx +++ /dev/null @@ -1,66 +0,0 @@ -'use client' - -// Polyfill for Array.prototype.toSpliced (ES2023, Chrome 110+) -if (!Array.prototype.toSpliced) { - // eslint-disable-next-line no-extend-native - Array.prototype.toSpliced = function (this: T[], start: number, deleteCount?: number, ...items: T[]): T[] { - const copy = this.slice() - // When deleteCount is undefined (omitted), delete to end; otherwise let splice handle coercion - if (deleteCount === undefined) - copy.splice(start, copy.length - start, ...items) - else - copy.splice(start, deleteCount, ...items) - return copy - } -} - -class StorageMock { - data: Record - - constructor() { - this.data = {} as Record - } - - setItem(name: string, value: string) { - this.data[name] = value - } - - getItem(name: string) { - return this.data[name] || null - } - - removeItem(name: string) { - delete this.data[name] - } - - clear() { - this.data = {} - } -} - -let localStorage, sessionStorage - -try { - localStorage = globalThis.localStorage - sessionStorage = globalThis.sessionStorage -} -catch { - localStorage = new StorageMock() - sessionStorage = new StorageMock() -} - -Object.defineProperty(globalThis, 'localStorage', { - value: localStorage, -}) - -Object.defineProperty(globalThis, 'sessionStorage', { - value: sessionStorage, -}) - -const BrowserInitializer = ({ - children, -}: { children: React.ReactElement }) => { - return children -} - -export default BrowserInitializer diff --git a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx index 1103da3f36..fcaca86e89 100644 --- a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx +++ b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx @@ -1,10 +1,14 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments } from '@/service/knowledge/use-document' import AutoDisabledDocument from '../auto-disabled-document' +const { mockToastSuccess } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), +})) + type AutoDisabledDocumentsResponse = { document_ids: string[] } const createMockQueryResult = ( @@ -26,9 +30,9 @@ vi.mock('@/service/knowledge/use-document', () => ({ useInvalidDisabledDocument: vi.fn(() => mockInvalidDisabledDocument), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, }, })) @@ -134,10 +138,7 @@ describe('AutoDisabledDocument', () => { fireEvent.click(actionButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'success', - message: expect.any(String), - }) + expect(toast.success).toHaveBeenCalledWith(expect.any(String)) }) }) }) diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx index a67c110849..c6c7e03bd1 100644 --- a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' import StatusWithAction from './status-with-action' @@ -23,7 +23,7 @@ const AutoDisabledDocument: FC = ({ const handleEnableDocuments = useCallback(async () => { await enableDocument({ datasetId, documentIds }) invalidDisabledDocument() - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) }, []) if (!hasDisabledDocument || isLoading) return null diff --git a/web/app/components/datasets/common/image-uploader/hooks/__tests__/use-upload.spec.tsx b/web/app/components/datasets/common/image-uploader/hooks/__tests__/use-upload.spec.tsx index f37dbd41f4..47a29fcfa1 100644 --- a/web/app/components/datasets/common/image-uploader/hooks/__tests__/use-upload.spec.tsx +++ b/web/app/components/datasets/common/image-uploader/hooks/__tests__/use-upload.spec.tsx @@ -3,10 +3,14 @@ import type { FileEntity } from '../../types' import { act, fireEvent, render, renderHook, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { FileContextProvider } from '../../store' import { useUpload } from '../use-upload' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + vi.mock('@/service/use-common', () => ({ useFileUploadConfig: vi.fn(() => ({ data: { @@ -17,9 +21,9 @@ vi.mock('@/service/use-common', () => ({ })), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, }, })) @@ -177,10 +181,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -204,13 +205,11 @@ describe('useUpload hook', () => { result.current.fileChangeHandle(mockEvent) }) - // Should not show type error for valid image type - type ToastCall = [{ type: string, message: string }] - const mockNotify = vi.mocked(Toast.notify) + // Should not show file-extension error for valid image type + type ToastCall = [string] + const mockNotify = vi.mocked(toast.error) const calls = mockNotify.mock.calls as ToastCall[] - const typeErrorCalls = calls.filter( - (call: ToastCall) => call[0].type === 'error' && call[0].message.includes('Extension'), - ) + const typeErrorCalls = calls.filter(call => call[0].includes('common.fileUploader.fileExtensionNotSupport')) expect(typeErrorCalls.length).toBe(0) }) }) @@ -261,7 +260,7 @@ describe('useUpload hook', () => { }) // Should not throw and not show error - expect(Toast.notify).not.toHaveBeenCalled() + expect(toast.error).not.toHaveBeenCalled() }) it('should handle null files', () => { @@ -314,10 +313,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) }) @@ -419,10 +415,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Upload error', - }) + expect(toast.error).toHaveBeenCalledWith('Upload error') }) }) }) @@ -481,10 +474,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Upload error', - }) + expect(toast.error).toHaveBeenCalledWith('Upload error') }) }) }) @@ -522,10 +512,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) }) @@ -610,10 +597,7 @@ describe('useUpload hook', () => { }) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) // Restore original MockFileReader @@ -773,10 +757,7 @@ describe('useUpload hook', () => { // Should show error toast for invalid file type await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) diff --git a/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts index ab7b8cbf28..d262401f4b 100644 --- a/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts +++ b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts @@ -4,7 +4,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { v4 as uuid4 } from 'uuid' import { fileUpload, getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useFileUploadConfig } from '@/service/use-common' import { ACCEPT_TYPES } from '../constants' import { useFileStore } from '../store' @@ -54,9 +54,9 @@ export const useUpload = () => { const showErrorMessage = useCallback((type: 'type' | 'size') => { if (type === 'type') - Toast.notify({ type: 'error', message: t('fileUploader.fileExtensionNotSupport', { ns: 'common' }) }) + toast.error(t('fileUploader.fileExtensionNotSupport', { ns: 'common' })) else - Toast.notify({ type: 'error', message: t('imageUploader.fileSizeLimitExceeded', { ns: 'dataset', size: fileUploadConfig.imageFileSizeLimit }) }) + toast.error(t('imageUploader.fileSizeLimitExceeded', { ns: 'dataset', size: fileUploadConfig.imageFileSizeLimit })) }, [fileUploadConfig, t]) const getValidFiles = useCallback((files: File[]) => { @@ -146,7 +146,7 @@ export const useUpload = () => { }, onErrorCallback: (error?: any) => { const errorMessage = getFileUploadErrorMessage(error, t('fileUploader.uploadFromComputerUploadError', { ns: 'common' }), t) - Toast.notify({ type: 'error', message: errorMessage }) + toast.error(errorMessage) handleUpdateFile({ ...uploadingFile, progress: -1 }) }, }) @@ -188,7 +188,7 @@ export const useUpload = () => { }, onErrorCallback: (error?: any) => { const errorMessage = getFileUploadErrorMessage(error, t('fileUploader.uploadFromComputerUploadError', { ns: 'common' }), t) - Toast.notify({ type: 'error', message: errorMessage }) + toast.error(errorMessage) handleUpdateFile({ ...uploadingFile, progress: -1 }) }, }) @@ -198,7 +198,7 @@ export const useUpload = () => { reader.addEventListener( 'error', () => { - Toast.notify({ type: 'error', message: t('fileUploader.uploadFromComputerReadError', { ns: 'common' }) }) + toast.error(t('fileUploader.uploadFromComputerReadError', { ns: 'common' })) }, false, ) @@ -211,10 +211,7 @@ export const useUpload = () => { if (newFiles.length === 0) return if (files.length + newFiles.length > singleChunkAttachmentLimit) { - Toast.notify({ - type: 'error', - message: t('imageUploader.singleChunkAttachmentLimitTooltip', { ns: 'datasetHitTesting', limit: singleChunkAttachmentLimit }), - }) + toast.error(t('imageUploader.singleChunkAttachmentLimitTooltip', { ns: 'datasetHitTesting', limit: singleChunkAttachmentLimit })) return } for (const file of newFiles) diff --git a/web/app/components/datasets/common/retrieval-param-config/__tests__/index.spec.tsx b/web/app/components/datasets/common/retrieval-param-config/__tests__/index.spec.tsx index f5b41688e1..db6086ca34 100644 --- a/web/app/components/datasets/common/retrieval-param-config/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/__tests__/index.spec.tsx @@ -5,9 +5,9 @@ import { RETRIEVE_METHOD } from '@/types/app' import RetrievalParamConfig from '../index' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: { type: string, message: string }) => mockNotify(params), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: (message: string) => mockNotify(message), }, })) @@ -260,10 +260,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(screen.getByTestId('rerank-switch')) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'workflow.errorMsg.rerankModelRequired', - }) + expect(mockNotify).toHaveBeenCalledWith('workflow.errorMsg.rerankModelRequired') }) it('should update reranking model on selection', () => { @@ -618,10 +615,7 @@ describe('RetrievalParamConfig', () => { const rerankModelCard = radioCards.find(card => card.getAttribute('data-title') === 'common.modelProvider.rerankModel.key') fireEvent.click(rerankModelCard!) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'workflow.errorMsg.rerankModelRequired', - }) + expect(mockNotify).toHaveBeenCalledWith('workflow.errorMsg.rerankModelRequired') }) it('should update weights when WeightedScore changes', () => { diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 2414c29a8c..e0fd245ef5 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -11,8 +11,8 @@ import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold import TopKItem from '@/app/components/base/param-item/top-k-item' import RadioCard from '@/app/components/base/radio-card' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' @@ -59,7 +59,7 @@ const RetrievalParamConfig: FC = ({ const handleToggleRerankEnable = useCallback((enable: boolean) => { if (enable && !currentModel) - Toast.notify({ type: 'error', message: t('errorMsg.rerankModelRequired', { ns: 'workflow' }) }) + toast.error(t('errorMsg.rerankModelRequired', { ns: 'workflow' })) onChange({ ...value, reranking_enable: enable, @@ -96,7 +96,7 @@ const RetrievalParamConfig: FC = ({ } } if (v === RerankingModeEnum.RerankingModel && !currentModel) - Toast.notify({ type: 'error', message: t('errorMsg.rerankModelRequired', { ns: 'workflow' }) }) + toast.error(t('errorMsg.rerankModelRequired', { ns: 'workflow' })) onChange(result) } diff --git a/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx b/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx index 19f1f74e1d..7f1bc0e00c 100644 --- a/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx @@ -7,7 +7,7 @@ import Footer from '../footer' let mockSearchParams = new URLSearchParams() const mockReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), useSearchParams: () => mockSearchParams, })) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx index 820332dcc3..7f292c8ff9 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx @@ -8,7 +8,7 @@ import TabItem from '../tab/item' import Uploader from '../uploader' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx index ac56206003..f97b14af0f 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx @@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { CreateFromDSLModalTab, useDSLImport } from '../use-dsl-import' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts index c839fad3a2..ff7aa1cafb 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts @@ -1,6 +1,5 @@ 'use client' import { useDebounceFn } from 'ahooks' -import { useRouter } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -10,6 +9,7 @@ import { DSLImportMode, DSLImportStatus, } from '@/models/app' +import { useRouter } from '@/next/navigation' import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline' export enum CreateFromDSLModalTab { diff --git a/web/app/components/datasets/create-from-pipeline/footer.tsx b/web/app/components/datasets/create-from-pipeline/footer.tsx index 23e83d1da3..ae1bb48394 100644 --- a/web/app/components/datasets/create-from-pipeline/footer.tsx +++ b/web/app/components/datasets/create-from-pipeline/footer.tsx @@ -1,8 +1,8 @@ import { RiFileUploadLine } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from '@/next/navigation' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import Divider from '../../base/divider' import CreateFromDSLModal, { CreateFromDSLModalTab } from './create-options/create-from-dsl-modal' diff --git a/web/app/components/datasets/create-from-pipeline/header.tsx b/web/app/components/datasets/create-from-pipeline/header.tsx index 99738edb08..204b372a1d 100644 --- a/web/app/components/datasets/create-from-pipeline/header.tsx +++ b/web/app/components/datasets/create-from-pipeline/header.tsx @@ -1,7 +1,7 @@ import { RiArrowLeftLine } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' import Button from '../../base/button' const Header = () => { diff --git a/web/app/components/datasets/create-from-pipeline/list/__tests__/create-card.spec.tsx b/web/app/components/datasets/create-from-pipeline/list/__tests__/create-card.spec.tsx index 96bc82f010..773e7e7f74 100644 --- a/web/app/components/datasets/create-from-pipeline/list/__tests__/create-card.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/__tests__/create-card.spec.tsx @@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import CreateCard from '../create-card' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush }), })) @@ -13,12 +13,23 @@ vi.mock('@/app/components/base/amplitude', () => ({ trackEvent: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + error: mockToastError, + }, + } +}) + const mockCreateEmptyDataset = vi.fn() const mockInvalidDatasetList = vi.fn() @@ -37,6 +48,8 @@ vi.mock('@/service/knowledge/use-dataset', () => ({ describe('CreateCard', () => { beforeEach(() => { vi.clearAllMocks() + mockToastSuccess.mockReset() + mockToastError.mockReset() }) describe('Rendering', () => { diff --git a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx index b32a7dba2d..01443b5401 100644 --- a/web/app/components/datasets/create-from-pipeline/list/create-card.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/create-card.tsx @@ -1,10 +1,10 @@ import { RiAddCircleLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' +import { useRouter } from '@/next/navigation' import { useCreatePipelineDataset } from '@/service/knowledge/use-create-dataset' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' @@ -20,10 +20,7 @@ const CreateCard = () => { onSuccess: (data) => { if (data) { const { id } = data - Toast.notify({ - type: 'success', - message: t('creation.successTip', { ns: 'datasetPipeline' }), - }) + toast.success(t('creation.successTip', { ns: 'datasetPipeline' })) invalidDatasetList() trackEvent('create_datasets_from_scratch', { dataset_id: id, @@ -32,10 +29,7 @@ const CreateCard = () => { } }, onError: () => { - Toast.notify({ - type: 'error', - message: t('creation.errorTip', { ns: 'datasetPipeline' }), - }) + toast.error(t('creation.errorTip', { ns: 'datasetPipeline' })) }, }) }, [createEmptyDataset, push, invalidDatasetList, t]) diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/edit-pipeline-info.spec.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/edit-pipeline-info.spec.tsx index 9c9c80c902..d7f990aa82 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/edit-pipeline-info.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/edit-pipeline-info.spec.tsx @@ -1,8 +1,6 @@ import type { PipelineTemplate } from '@/models/pipeline' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' - -import Toast from '@/app/components/base/toast' import { ChunkingMode } from '@/models/datasets' import EditPipelineInfo from '../edit-pipeline-info' @@ -16,12 +14,21 @@ vi.mock('@/service/use-pipeline', () => ({ useInvalidCustomizedTemplateList: () => mockInvalidCustomizedTemplateList, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) + // Mock AppIconPicker to capture interactions let _mockOnSelect: ((icon: { type: 'emoji' | 'image', icon?: string, background?: string, fileId?: string, url?: string }) => void) | undefined let _mockOnClose: (() => void) | undefined @@ -88,6 +95,7 @@ describe('EditPipelineInfo', () => { beforeEach(() => { vi.clearAllMocks() + mockToastError.mockReset() _mockOnSelect = undefined _mockOnClose = undefined }) @@ -235,10 +243,7 @@ describe('EditPipelineInfo', () => { fireEvent.click(saveButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Please enter a name for the Knowledge Base.', - }) + expect(mockToastError).toHaveBeenCalledWith('datasetPipeline.editPipelineInfoNameRequired') }) }) diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/index.spec.tsx index 4455672383..4ce4ecdb87 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/__tests__/index.spec.tsx @@ -1,12 +1,11 @@ import type { PipelineTemplate } from '@/models/pipeline' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import { ChunkingMode } from '@/models/datasets' import TemplateCard from '../index' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush }), })) @@ -15,12 +14,23 @@ vi.mock('@/app/components/base/amplitude', () => ({ trackEvent: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + error: mockToastError, + }, + } +}) + // Mock download utilities vi.mock('@/utils/download', () => ({ downloadBlob: vi.fn(), @@ -174,6 +184,8 @@ describe('TemplateCard', () => { beforeEach(() => { vi.clearAllMocks() + mockToastSuccess.mockReset() + mockToastError.mockReset() mockIsExporting = false _capturedOnConfirm = undefined _capturedOnCancel = undefined @@ -228,10 +240,7 @@ describe('TemplateCard', () => { fireEvent.click(chooseButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -291,10 +300,7 @@ describe('TemplateCard', () => { fireEvent.click(chooseButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'success', - message: expect.any(String), - }) + expect(mockToastSuccess).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -309,10 +315,7 @@ describe('TemplateCard', () => { fireEvent.click(chooseButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) }) @@ -458,10 +461,7 @@ describe('TemplateCard', () => { fireEvent.click(exportButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'success', - message: expect.any(String), - }) + expect(mockToastSuccess).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -476,10 +476,7 @@ describe('TemplateCard', () => { fireEvent.click(exportButton) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/edit-pipeline-info.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/edit-pipeline-info.tsx index 69f8f470d0..ae34015559 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/edit-pipeline-info.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/edit-pipeline-info.tsx @@ -9,7 +9,7 @@ import AppIconPicker from '@/app/components/base/app-icon-picker' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useInvalidCustomizedTemplateList, useUpdateTemplateInfo } from '@/service/use-pipeline' type EditPipelineInfoProps = { @@ -67,10 +67,7 @@ const EditPipelineInfo = ({ const handleSave = useCallback(async () => { if (!name) { - Toast.notify({ - type: 'error', - message: 'Please enter a name for the Knowledge Base.', - }) + toast.error(t('editPipelineInfoNameRequired', { ns: 'datasetPipeline' })) return } const request = { diff --git a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx index b3395a83d5..d7881708d6 100644 --- a/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx @@ -1,13 +1,13 @@ import type { PipelineTemplate } from '@/models/pipeline' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' import Confirm from '@/app/components/base/confirm' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { useRouter } from '@/next/navigation' import { useCreatePipelineDatasetFromCustomized } from '@/service/knowledge/use-create-dataset' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { @@ -50,10 +50,7 @@ const TemplateCard = ({ const handleUseTemplate = useCallback(async () => { const { data: pipelineTemplateInfo } = await getPipelineTemplateInfo() if (!pipelineTemplateInfo) { - Toast.notify({ - type: 'error', - message: t('creation.errorTip', { ns: 'datasetPipeline' }), - }) + toast.error(t('creation.errorTip', { ns: 'datasetPipeline' })) return } const request = { @@ -61,10 +58,7 @@ const TemplateCard = ({ } await createDataset(request, { onSuccess: async (newDataset) => { - Toast.notify({ - type: 'success', - message: t('creation.successTip', { ns: 'datasetPipeline' }), - }) + toast.success(t('creation.successTip', { ns: 'datasetPipeline' })) invalidDatasetList() if (newDataset.pipeline_id) await handleCheckPluginDependencies(newDataset.pipeline_id, true) @@ -76,10 +70,7 @@ const TemplateCard = ({ push(`/datasets/${newDataset.dataset_id}/pipeline`) }, onError: () => { - Toast.notify({ - type: 'error', - message: t('creation.errorTip', { ns: 'datasetPipeline' }), - }) + toast.error(t('creation.errorTip', { ns: 'datasetPipeline' })) }, }) }, [getPipelineTemplateInfo, createDataset, t, handleCheckPluginDependencies, push, invalidDatasetList, pipeline.name, pipeline.id, type]) @@ -109,16 +100,10 @@ const TemplateCard = ({ onSuccess: (res) => { const blob = new Blob([res.data], { type: 'application/yaml' }) downloadBlob({ data: blob, fileName: `${pipeline.name}.pipeline` }) - Toast.notify({ - type: 'success', - message: t('exportDSL.successTip', { ns: 'datasetPipeline' }), - }) + toast.success(t('exportDSL.successTip', { ns: 'datasetPipeline' })) }, onError: () => { - Toast.notify({ - type: 'error', - message: t('exportDSL.errorTip', { ns: 'datasetPipeline' }), - }) + toast.error(t('exportDSL.errorTip', { ns: 'datasetPipeline' })) }, }) }, [t, isExporting, pipeline.id, pipeline.name, exportPipelineDSL]) diff --git a/web/app/components/datasets/create/__tests__/index.spec.tsx b/web/app/components/datasets/create/__tests__/index.spec.tsx index 793bc21344..59d5dd891a 100644 --- a/web/app/components/datasets/create/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/__tests__/index.spec.tsx @@ -24,7 +24,7 @@ const IndexingTypeValues = { } // Mock next/link -vi.mock('next/link', () => { +vi.mock('@/next/link', () => { return function MockLink({ children, href }: { children: React.ReactNode, href: string }) { return
    {children} } diff --git a/web/app/components/datasets/create/embedding-process/__tests__/index.spec.tsx b/web/app/components/datasets/create/embedding-process/__tests__/index.spec.tsx index 686139250a..d1787fc47a 100644 --- a/web/app/components/datasets/create/embedding-process/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/embedding-process/__tests__/index.spec.tsx @@ -16,7 +16,7 @@ import { const mockPush = vi.fn() const mockRouter = { push: mockPush } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => mockRouter, })) diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index e9cea84f00..812eb2e51c 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -6,8 +6,6 @@ import { RiLoader2Fill, RiTerminalBoxLine, } from '@remixicon/react' -import Link from 'next/link' -import { useRouter } from 'next/navigation' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -15,6 +13,8 @@ import Divider from '@/app/components/base/divider' import { Plan } from '@/app/components/billing/type' import { useProviderContext } from '@/context/provider-context' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' +import Link from '@/next/link' +import { useRouter } from '@/next/navigation' import { useProcessRule } from '@/service/knowledge/use-dataset' import { useInvalidDocumentList } from '@/service/knowledge/use-document' import IndexingProgressItem from './indexing-progress-item' diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/__tests__/index.spec.tsx index f5379bc543..2df124d7b6 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ import EmptyDatasetCreationModal from '../index' // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index 0a4064de2a..b417c15e8f 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -1,5 +1,4 @@ 'use client' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' @@ -9,6 +8,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast/context' +import { useRouter } from '@/next/navigation' import { createEmptyDataset } from '@/service/datasets' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' diff --git a/web/app/components/datasets/create/file-uploader/__tests__/index.spec.tsx b/web/app/components/datasets/create/file-uploader/__tests__/index.spec.tsx index da337efce2..c0635bebd1 100644 --- a/web/app/components/datasets/create/file-uploader/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/file-uploader/__tests__/index.spec.tsx @@ -58,7 +58,7 @@ vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ })) // Mock SimplePieChart -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: () => { const Component = ({ percentage }: { percentage: number }) => (
    diff --git a/web/app/components/datasets/create/file-uploader/components/__tests__/file-list-item.spec.tsx b/web/app/components/datasets/create/file-uploader/components/__tests__/file-list-item.spec.tsx index dd88af4395..e7a25cbdd8 100644 --- a/web/app/components/datasets/create/file-uploader/components/__tests__/file-list-item.spec.tsx +++ b/web/app/components/datasets/create/file-uploader/components/__tests__/file-list-item.spec.tsx @@ -17,7 +17,7 @@ vi.mock('@/types/app', () => ({ })) // Mock SimplePieChart with dynamic import handling -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: () => { const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => (
    diff --git a/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx index d36773fa5c..2f51a9f767 100644 --- a/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx +++ b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx @@ -1,10 +1,10 @@ 'use client' import type { CustomFile as File, FileItem } from '@/models/datasets' import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react' -import dynamic from 'next/dynamic' import { useMemo } from 'react' import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' import useTheme from '@/hooks/use-theme' +import dynamic from '@/next/dynamic' import { Theme } from '@/types/app' import { formatFileSize, getFileExtension } from '@/utils/format' import { PROGRESS_COMPLETE, PROGRESS_ERROR } from '../constants' diff --git a/web/app/components/datasets/create/step-two/components/__tests__/indexing-mode-section.spec.tsx b/web/app/components/datasets/create/step-two/components/__tests__/indexing-mode-section.spec.tsx index 43a944dcd4..e46ff6d484 100644 --- a/web/app/components/datasets/create/step-two/components/__tests__/indexing-mode-section.spec.tsx +++ b/web/app/components/datasets/create/step-two/components/__tests__/indexing-mode-section.spec.tsx @@ -6,7 +6,7 @@ import { ChunkingMode } from '@/models/datasets' import { IndexingType } from '../../hooks' import { IndexingModeSection } from '../indexing-mode-section' -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, ...props }: { children?: React.ReactNode, href?: string, className?: string }) => {children}, })) diff --git a/web/app/components/datasets/create/step-two/components/indexing-mode-section.tsx b/web/app/components/datasets/create/step-two/components/indexing-mode-section.tsx index da309348cc..8b49a00500 100644 --- a/web/app/components/datasets/create/step-two/components/indexing-mode-section.tsx +++ b/web/app/components/datasets/create/step-two/components/indexing-mode-section.tsx @@ -3,7 +3,6 @@ import type { FC } from 'react' import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { RetrievalConfig } from '@/types/app' -import Link from 'next/link' import { useTranslation } from 'react-i18next' import Badge from '@/app/components/base/badge' import Button from '@/app/components/base/button' @@ -16,6 +15,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useDocLink } from '@/context/i18n' import { ChunkingMode } from '@/models/datasets' +import Link from '@/next/link' import { cn } from '@/utils/classnames' import { indexMethodIcon } from '../../icons' import { IndexingType } from '../hooks' diff --git a/web/app/components/datasets/create/top-bar/__tests__/index.spec.tsx b/web/app/components/datasets/create/top-bar/__tests__/index.spec.tsx index 4fc8d1852b..c038a371d6 100644 --- a/web/app/components/datasets/create/top-bar/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/top-bar/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { render, screen } from '@testing-library/react' import { TopBar } from '../index' // Mock next/link to capture href values -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, replace, className }: { children: React.ReactNode, href: string, replace?: boolean, className?: string }) => ( {children} diff --git a/web/app/components/datasets/create/top-bar/index.tsx b/web/app/components/datasets/create/top-bar/index.tsx index 0051430511..ba4c49e300 100644 --- a/web/app/components/datasets/create/top-bar/index.tsx +++ b/web/app/components/datasets/create/top-bar/index.tsx @@ -1,9 +1,9 @@ import type { FC } from 'react' import type { StepperProps } from '../stepper' import { RiArrowLeftLine } from '@remixicon/react' -import Link from 'next/link' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' import { cn } from '@/utils/classnames' import { Stepper } from '../stepper' diff --git a/web/app/components/datasets/documents/__tests__/index.spec.tsx b/web/app/components/datasets/documents/__tests__/index.spec.tsx index f464c97395..2dd91dd7f3 100644 --- a/web/app/components/datasets/documents/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/__tests__/index.spec.tsx @@ -13,7 +13,7 @@ type MockState = Parameters[0] // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, replace: vi.fn(), diff --git a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx index 5422c23b9a..ce73368e1a 100644 --- a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx @@ -4,7 +4,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import Operations from '../operations' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/documents/components/__tests__/rename-modal.spec.tsx b/web/app/components/datasets/documents/components/__tests__/rename-modal.spec.tsx index 9ed61a66e0..ad40e52752 100644 --- a/web/app/components/datasets/documents/components/__tests__/rename-modal.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/rename-modal.spec.tsx @@ -5,11 +5,23 @@ import { renameDocumentName } from '@/service/datasets' import RenameModal from '../rename-modal' +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + // Mock the service vi.mock('@/service/datasets', () => ({ renameDocumentName: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, + }, +})) + const mockRenameDocumentName = vi.mocked(renameDocumentName) describe('RenameModal', () => { @@ -118,6 +130,7 @@ describe('RenameModal', () => { await waitFor(() => { expect(handleSaved).toHaveBeenCalledTimes(1) expect(handleClose).toHaveBeenCalledTimes(1) + expect(mockToastSuccess).toHaveBeenCalledWith(expect.any(String)) }) }) }) @@ -163,6 +176,7 @@ describe('RenameModal', () => { // onSaved and onClose should not be called on error expect(handleSaved).not.toHaveBeenCalled() expect(handleClose).not.toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalledWith('Error: API Error') }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx index 279c85f2f0..48e6b58766 100644 --- a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx @@ -9,7 +9,7 @@ import DocumentList from '../../list' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx index 1c5145f7ed..d5e4f480be 100644 --- a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx @@ -9,7 +9,7 @@ import DocumentTableRow from '../document-table-row' const mockPush = vi.fn() let mockSearchParams = '' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx index 3694b81138..c5f0f0af37 100644 --- a/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx @@ -1,7 +1,6 @@ import type { FC } from 'react' import type { SimpleDocumentDetail } from '@/models/datasets' import { pick } from 'es-toolkit/object' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -13,6 +12,7 @@ import SummaryStatus from '@/app/components/datasets/documents/detail/completed/ import StatusItem from '@/app/components/datasets/documents/status-item' import useTimestamp from '@/hooks/use-timestamp' import { DataSourceType } from '@/models/datasets' +import { useRouter, useSearchParams } from '@/next/navigation' import { formatNumber } from '@/utils/format' import DocumentSourceIcon from './document-source-icon' import { renderTdValue } from './utils' diff --git a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts index 5f48be084e..449478eb7b 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts @@ -3,6 +3,11 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { DocumentActionType } from '@/models/datasets' import { useDocumentActions } from '../use-document-actions' +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + const mockArchive = vi.fn() const mockSummary = vi.fn() const mockEnable = vi.fn() @@ -22,9 +27,11 @@ vi.mock('@/service/knowledge/use-document', () => ({ useDocumentDownloadZip: () => ({ mutateAsync: mockDownloadZip, isPending: mockIsDownloadingZip }), })) -const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, + }, })) const mockDownloadBlob = vi.fn() @@ -67,9 +74,7 @@ describe('useDocumentActions', () => { datasetId: 'ds-1', documentIds: ['doc-1', 'doc-2'], }) - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'success' }), - ) + expect(mockToastSuccess).toHaveBeenCalledWith(expect.any(String)) expect(defaultOptions.onUpdate).toHaveBeenCalled() }) @@ -142,9 +147,7 @@ describe('useDocumentActions', () => { await result.current.handleAction(DocumentActionType.archive)() }) - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) expect(defaultOptions.onUpdate).not.toHaveBeenCalled() }) }) @@ -174,9 +177,7 @@ describe('useDocumentActions', () => { await result.current.handleBatchReIndex() }) - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -210,9 +211,7 @@ describe('useDocumentActions', () => { await result.current.handleBatchDownload() }) - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) it('should show error toast when blob is null', async () => { @@ -223,9 +222,7 @@ describe('useDocumentActions', () => { await result.current.handleBatchDownload() }) - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts index 56553faa9e..8b6c40e2be 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts @@ -1,7 +1,7 @@ import type { CommonResponse } from '@/models/common' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { DocumentActionType } from '@/models/datasets' import { useDocumentArchive, @@ -79,11 +79,11 @@ export const useDocumentActions = ({ if (!e) { if (actionName === DocumentActionType.delete) onClearSelection() - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) onUpdate() } else { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) } } }, [actionMutationMap, datasetId, selectedIds, onClearSelection, onUpdate, t]) @@ -94,11 +94,11 @@ export const useDocumentActions = ({ ) if (!e) { onClearSelection() - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) onUpdate() } else { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) } }, [retryIndexDocument, datasetId, selectedIds, onClearSelection, onUpdate, t]) @@ -110,7 +110,7 @@ export const useDocumentActions = ({ requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds }), ) if (e || !blob) { - Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.downloadUnsuccessfully', { ns: 'common' })) return } diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index 84e16c7c48..ff3563c3fe 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -14,7 +14,6 @@ import { } from '@remixicon/react' import { useBoolean, useDebounceFn } from 'ahooks' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -28,6 +27,7 @@ import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' +import { useRouter } from '@/next/navigation' import { useDocumentArchive, useDocumentDelete, diff --git a/web/app/components/datasets/documents/components/rename-modal.tsx b/web/app/components/datasets/documents/components/rename-modal.tsx index a119a2da9e..364aaf48e6 100644 --- a/web/app/components/datasets/documents/components/rename-modal.tsx +++ b/web/app/components/datasets/documents/components/rename-modal.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { renameDocumentName } from '@/service/datasets' type Props = { @@ -41,13 +41,13 @@ const RenameModal: FC = ({ documentId, name: newName, }) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) onSaved() onClose() } catch (error) { if (error) - Toast.notify({ type: 'error', message: error.toString() }) + toast.error(error.toString()) } finally { setSaveLoadingFalse() diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx index 0096dc8c29..8a2e251770 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx @@ -90,7 +90,7 @@ vi.mock('@/app/components/base/amplitude', () => ({ trackEvent: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: 'test-dataset-id' }), useRouter: () => ({ push: vi.fn(), @@ -101,7 +101,7 @@ vi.mock('next/navigation', () => ({ })) // Mock next/link -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children: React.ReactNode, href: string }) => ( {children} ), diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/left-header.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/left-header.spec.tsx index 584c21e826..c4ddec7434 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/left-header.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/left-header.spec.tsx @@ -3,11 +3,11 @@ import { render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import LeftHeader from '../left-header' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: 'test-ds-id' }), })) -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children: React.ReactNode, href: string }) => ( {children} ), diff --git a/web/app/components/datasets/documents/create-from-pipeline/actions/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/actions/__tests__/index.spec.tsx index 45ecaa7e9b..93861ef76a 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/actions/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/actions/__tests__/index.spec.tsx @@ -4,12 +4,12 @@ import Actions from '../index' // Mock next/navigation - useParams returns datasetId const mockDatasetId = 'test-dataset-id' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: mockDatasetId }), })) // Mock next/link to capture href -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, replace }: { children: React.ReactNode, href: string, replace?: boolean }) => ( {children} diff --git a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx index de0609b4d8..dab76da832 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx @@ -1,11 +1,11 @@ import { RiArrowRightLine } from '@remixicon/react' -import Link from 'next/link' -import { useParams } from 'next/navigation' import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Checkbox from '@/app/components/base/checkbox' +import Link from '@/next/link' +import { useParams } from '@/next/navigation' type ActionsProps = { disabled?: boolean diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/__tests__/index.spec.tsx index 87010638b2..4ec21ab1fb 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/__tests__/index.spec.tsx @@ -26,7 +26,7 @@ vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ })) // Mock SimplePieChart -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: () => { const Component = ({ percentage }: { percentage: number }) => (
    diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/__tests__/file-list-item.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/__tests__/file-list-item.spec.tsx index df7fe3540b..fcb0878978 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/__tests__/file-list-item.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/__tests__/file-list-item.spec.tsx @@ -17,7 +17,7 @@ vi.mock('@/types/app', () => ({ })) // Mock SimplePieChart with dynamic import handling -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: () => { const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => (
    diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx index 1a61fa04f0..4338dd05d4 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx @@ -1,10 +1,10 @@ import type { CustomFile as File, FileItem } from '@/models/datasets' import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react' -import dynamic from 'next/dynamic' import { useMemo } from 'react' import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' import { getFileType } from '@/app/components/datasets/common/image-uploader/utils' import useTheme from '@/hooks/use-theme' +import dynamic from '@/next/dynamic' import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx index 894ee60060..6be0e28d31 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx @@ -32,16 +32,21 @@ vi.mock('@/service/base', () => ({ ssePost: mockSsePost, })) -// Mock Toast.notify - static method that manipulates DOM, needs mocking to verify calls -const { mockToastNotify } = vi.hoisted(() => ({ - mockToastNotify: vi.fn(), +// Mock toast.error because the component reports errors through the UI toast manager. +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: mockToastNotify, - }, -})) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) // Mock useGetDataSourceAuth - API service hook requires mocking const { mockUseGetDataSourceAuth } = vi.hoisted(() => ({ @@ -192,6 +197,7 @@ const createDefaultProps = (overrides?: Partial): OnlineDo describe('OnlineDocuments', () => { beforeEach(() => { vi.clearAllMocks() + mockToastError.mockReset() // Reset store state mockStoreState.documentsData = [] @@ -509,10 +515,7 @@ describe('OnlineDocuments', () => { render() await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'Something went wrong', - }) + expect(mockToastError).toHaveBeenCalledWith('Something went wrong') }) }) @@ -774,10 +777,7 @@ describe('OnlineDocuments', () => { render() await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'API Error Message', - }) + expect(mockToastError).toHaveBeenCalledWith('API Error Message') }) }) @@ -1094,10 +1094,7 @@ describe('OnlineDocuments', () => { render() await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'Failed to fetch documents', - }) + expect(mockToastError).toHaveBeenCalledWith('Failed to fetch documents') }) // Should still show loading since documentsData is empty diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx index 4bdaac895b..15b9ee7332 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx @@ -5,7 +5,7 @@ import { useCallback, useEffect, useMemo } from 'react' import { useShallow } from 'zustand/react/shallow' import Loading from '@/app/components/base/loading' import SearchInput from '@/app/components/base/notion-page-selector/search-input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' @@ -96,10 +96,7 @@ const OnlineDocuments = ({ setDocumentsData(documentsData.data as DataSourceNotionWorkspace[]) }, onDataSourceNodeError: (error: DataSourceNodeErrorResponse) => { - Toast.notify({ - type: 'error', - message: error.error, - }) + toast.error(error.error) }, }, ) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx index 1721b72e1c..7c1941afd9 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx @@ -45,15 +45,20 @@ vi.mock('@/service/use-datasource', () => ({ useGetDataSourceAuth: mockUseGetDataSourceAuth, })) -const { mockToastNotify } = vi.hoisted(() => ({ - mockToastNotify: vi.fn(), +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: mockToastNotify, - }, -})) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) // Note: zustand/react/shallow useShallow is imported directly (simple utility function) @@ -231,6 +236,7 @@ const resetMockStoreState = () => { describe('OnlineDrive', () => { beforeEach(() => { vi.clearAllMocks() + mockToastError.mockReset() // Reset store state resetMockStoreState() @@ -541,10 +547,7 @@ describe('OnlineDrive', () => { render() await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'error', - message: errorMessage, - }) + expect(mockToastError).toHaveBeenCalledWith(errorMessage) }) }) }) @@ -915,10 +918,7 @@ describe('OnlineDrive', () => { render() await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith({ - type: 'error', - message: errorMessage, - }) + expect(mockToastError).toHaveBeenCalledWith(errorMessage) }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx index 4346a2d0af..2113e8841c 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx @@ -4,7 +4,7 @@ import type { DataSourceNodeCompletedResponse, DataSourceNodeErrorResponse } fro import { produce } from 'immer' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useShallow } from 'zustand/react/shallow' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' @@ -105,10 +105,7 @@ const OnlineDrive = ({ isLoadingRef.current = false }, onDataSourceNodeError: (error: DataSourceNodeErrorResponse) => { - Toast.notify({ - type: 'error', - message: error.error, - }) + toast.error(error.error) setIsLoading(false) isLoadingRef.current = false }, diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx index c147e969a6..cea569fa5f 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx @@ -1,13 +1,26 @@ -import type { MockInstance } from 'vitest' import type { RAGPipelineVariables } from '@/models/pipeline' import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import { BaseFieldType } from '@/app/components/base/form/form-scenarios/base/types' -import Toast from '@/app/components/base/toast' import { CrawlStep } from '@/models/datasets' import { PipelineInputVarType } from '@/models/pipeline' import Options from '../index' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) + // Mock useInitialData and useConfigurations hooks const { mockUseInitialData, mockUseConfigurations } = vi.hoisted(() => ({ mockUseInitialData: vi.fn(), @@ -116,13 +129,9 @@ const createDefaultProps = (overrides?: Partial): OptionsProps => }) describe('Options', () => { - let toastNotifySpy: MockInstance - beforeEach(() => { vi.clearAllMocks() - - // Spy on Toast.notify instead of mocking the entire module - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + mockToastError.mockReset() // Reset mock form values Object.keys(mockFormValues).forEach(key => delete mockFormValues[key]) @@ -132,10 +141,6 @@ describe('Options', () => { mockUseConfigurations.mockReturnValue([createMockConfiguration()]) }) - afterEach(() => { - toastNotifySpy.mockRestore() - }) - describe('Rendering', () => { it('should render without crashing', () => { const props = createDefaultProps() @@ -638,11 +643,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) // Assert - Toast should be called with error message - expect(toastNotifySpy).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(mockToastError).toHaveBeenCalled() }) it('should handle validation error and display field name in message', () => { @@ -660,12 +661,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) // Assert - Toast message should contain field path - expect(toastNotifySpy).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - message: expect.stringContaining('email_address'), - }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.stringContaining('email_address')) }) it('should handle empty variables gracefully', () => { @@ -714,12 +710,8 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) // Assert - Toast should be called once (only first error) - expect(toastNotifySpy).toHaveBeenCalledTimes(1) - expect(toastNotifySpy).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(mockToastError).toHaveBeenCalledTimes(1) + expect(mockToastError).toHaveBeenCalled() }) it('should handle validation pass when all required fields have values', () => { @@ -738,7 +730,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) // Assert - No toast error, onSubmit called - expect(toastNotifySpy).not.toHaveBeenCalled() + expect(mockToastError).not.toHaveBeenCalled() expect(mockOnSubmit).toHaveBeenCalled() }) @@ -835,7 +827,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) expect(mockOnSubmit).toHaveBeenCalled() - expect(toastNotifySpy).not.toHaveBeenCalled() + expect(mockToastError).not.toHaveBeenCalled() }) it('should fail validation with invalid data', () => { @@ -854,7 +846,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) expect(mockOnSubmit).not.toHaveBeenCalled() - expect(toastNotifySpy).toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalled() }) it('should show error toast message when validation fails', () => { @@ -871,12 +863,7 @@ describe('Options', () => { fireEvent.click(screen.getByRole('button')) - expect(toastNotifySpy).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - message: expect.any(String), - }), - ) + expect(mockToastError).toHaveBeenCalledWith(expect.any(String)) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx index eb8cceb3e5..02369131f7 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx @@ -8,7 +8,7 @@ import { useAppForm } from '@/app/components/base/form' import BaseField from '@/app/components/base/form/form-scenarios/base/field' import { generateZodSchema } from '@/app/components/base/form/form-scenarios/base/utils' import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useConfigurations, useInitialData } from '@/app/components/rag-pipeline/hooks/use-input-fields' import { CrawlStep } from '@/models/datasets' import { cn } from '@/utils/classnames' @@ -44,10 +44,7 @@ const Options = ({ const issues = result.error.issues const firstIssue = issues[0] const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) return errorMessage } return undefined diff --git a/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx b/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx index 2b30c79022..d464041d13 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx @@ -1,10 +1,10 @@ import type { Step } from './step-indicator' import { RiArrowLeftLine } from '@remixicon/react' -import Link from 'next/link' -import { useParams } from 'next/navigation' import * as React from 'react' import Button from '@/app/components/base/button' import Effect from '@/app/components/base/effect' +import Link from '@/next/link' +import { useParams } from '@/next/navigation' import StepIndicator from './step-indicator' type LeftHeaderProps = { diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx index 947313cda5..1e094fedb0 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx @@ -1,13 +1,24 @@ import type { NotionPage } from '@/models/common' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import Toast from '@/app/components/base/toast' import OnlineDocumentPreview from '../online-document-preview' // Uses global react-i18next mock from web/vitest.setup.ts -// Spy on Toast.notify -const toastNotifySpy = vi.spyOn(Toast, 'notify') +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) // Mock dataset-detail context - needs mock to control return values const mockPipelineId = vi.fn() @@ -56,6 +67,7 @@ const defaultProps = { describe('OnlineDocumentPreview', () => { beforeEach(() => { vi.clearAllMocks() + mockToastError.mockReset() mockPipelineId.mockReturnValue('pipeline-123') mockUsePreviewOnlineDocument.mockReturnValue({ mutateAsync: mockMutateAsync, @@ -258,10 +270,7 @@ describe('OnlineDocumentPreview', () => { render() await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: errorMessage, - }) + expect(mockToastError).toHaveBeenCalledWith(errorMessage) }) }) @@ -276,10 +285,7 @@ describe('OnlineDocumentPreview', () => { render() await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'Network Error', - }) + expect(mockToastError).toHaveBeenCalledWith('Network Error') }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx index 1e3019d427..ff2f9f46a4 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx @@ -6,7 +6,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { Notion } from '@/app/components/base/icons/src/public/common' import { Markdown } from '@/app/components/base/markdown' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { usePreviewOnlineDocument } from '@/service/use-pipeline' import { formatNumberAbbreviated } from '@/utils/format' @@ -44,10 +44,7 @@ const OnlineDocumentPreview = ({ setContent(data.content) }, onError(error) { - Toast.notify({ - type: 'error', - message: error.message, - }) + toast.error(error.message) }, }) }, [currentPage.page_id]) diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx index c82b5a8468..ff5f8afa66 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx @@ -3,13 +3,24 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import * as z from 'zod' import { BaseFieldType } from '@/app/components/base/form/form-scenarios/base/types' -import Toast from '@/app/components/base/toast' import Actions from '../actions' import Form from '../form' import Header from '../header' -// Spy on Toast.notify for validation tests -const toastNotifySpy = vi.spyOn(Toast, 'notify') +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) // Test Data Factory Functions @@ -335,7 +346,7 @@ describe('Form', () => { beforeEach(() => { vi.clearAllMocks() - toastNotifySpy.mockClear() + mockToastError.mockReset() }) describe('Rendering', () => { @@ -444,10 +455,7 @@ describe('Form', () => { // Assert - validation error should be shown await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: '"field1" is required', - }) + expect(mockToastError).toHaveBeenCalledWith('"field1" is required') }) }) }) @@ -566,10 +574,7 @@ describe('Form', () => { fireEvent.submit(form) await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: '"field1" is required', - }) + expect(mockToastError).toHaveBeenCalledWith('"field1" is required') }) }) @@ -583,7 +588,7 @@ describe('Form', () => { // Assert - wait a bit and verify onSubmit was not called await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalled() + expect(mockToastError).toHaveBeenCalled() }) expect(onSubmit).not.toHaveBeenCalled() }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx index 25ac817284..09f28fc5da 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx @@ -2,10 +2,23 @@ import type { BaseConfiguration } from '@/app/components/base/form/form-scenario import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { z } from 'zod' -import Toast from '@/app/components/base/toast' - import Form from '../form' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + error: mockToastError, + }, + } +}) + // Mock the Header component (sibling component, not a base component) vi.mock('../header', () => ({ default: ({ onReset, resetDisabled, onPreview, previewDisabled }: { @@ -44,7 +57,7 @@ const defaultProps = { describe('Form (process-documents)', () => { beforeEach(() => { vi.clearAllMocks() - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + mockToastError.mockReset() }) // Verify basic rendering of form structure @@ -106,9 +119,7 @@ describe('Form (process-documents)', () => { fireEvent.submit(form) await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastError).toHaveBeenCalledWith('"name" Name is required') }) }) @@ -121,7 +132,7 @@ describe('Form (process-documents)', () => { await waitFor(() => { expect(defaultProps.onSubmit).toHaveBeenCalled() }) - expect(Toast.notify).not.toHaveBeenCalled() + expect(mockToastError).not.toHaveBeenCalled() }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx index 4873931e8d..33703d56b2 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx @@ -3,7 +3,7 @@ import type { BaseConfiguration } from '@/app/components/base/form/form-scenario import { useCallback, useImperativeHandle } from 'react' import { useAppForm } from '@/app/components/base/form' import BaseField from '@/app/components/base/form/form-scenarios/base/field' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Header from './header' type OptionsProps = { @@ -34,10 +34,7 @@ const Form = ({ const issues = result.error.issues const firstIssue = issues[0] const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) return errorMessage } return undefined diff --git a/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/__tests__/index.spec.tsx index aa107b8635..f59f5c091b 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/__tests__/index.spec.tsx @@ -10,14 +10,14 @@ import { RETRIEVE_METHOD } from '@/types/app' import EmbeddingProcess from '../index' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), })) // Mock next/link -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: function MockLink({ children, href, ...props }: { children: React.ReactNode, href: string }) { return {children} }, diff --git a/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/index.tsx index a7834fc656..099c3018cd 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/processing/embedding-process/index.tsx @@ -10,8 +10,6 @@ import { RiLoader2Fill, RiTerminalBoxLine, } from '@remixicon/react' -import Link from 'next/link' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -26,6 +24,8 @@ import DocumentFileIcon from '@/app/components/datasets/common/document-file-ico import { useProviderContext } from '@/context/provider-context' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' import { DatasourceType } from '@/models/pipeline' +import Link from '@/next/link' +import { useRouter } from '@/next/navigation' import { useIndexingStatusBatch, useProcessRule } from '@/service/knowledge/use-dataset' import { useInvalidDocumentList } from '@/service/knowledge/use-document' import { cn } from '@/utils/classnames' diff --git a/web/app/components/datasets/documents/create-from-pipeline/steps/__tests__/step-one-content.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/steps/__tests__/step-one-content.spec.tsx index ff0c1b125c..2e121dbbd1 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/steps/__tests__/step-one-content.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/steps/__tests__/step-one-content.spec.tsx @@ -143,7 +143,7 @@ vi.mock('@/service/base', () => ({ upload: vi.fn().mockResolvedValue({ id: 'uploaded-file-id' }), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: 'mock-dataset-id' }), useRouter: () => ({ push: vi.fn() }), usePathname: () => '/datasets/mock-dataset-id', diff --git a/web/app/components/datasets/documents/detail/__tests__/document-title.spec.tsx b/web/app/components/datasets/documents/detail/__tests__/document-title.spec.tsx index e7945fc409..3eb1017b8d 100644 --- a/web/app/components/datasets/documents/detail/__tests__/document-title.spec.tsx +++ b/web/app/components/datasets/documents/detail/__tests__/document-title.spec.tsx @@ -5,7 +5,7 @@ import { ChunkingMode } from '@/models/datasets' import { DocumentTitle } from '../document-title' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx index f01a64e34e..be4d2304bd 100644 --- a/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/__tests__/index.spec.tsx @@ -25,7 +25,7 @@ const mocks = vi.hoisted(() => { }) // --- External mocks --- -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mocks.push }), useSearchParams: () => new URLSearchParams(mocks.state.searchParams), })) diff --git a/web/app/components/datasets/documents/detail/__tests__/new-segment.spec.tsx b/web/app/components/datasets/documents/detail/__tests__/new-segment.spec.tsx index 73082108a0..f243f85f29 100644 --- a/web/app/components/datasets/documents/detail/__tests__/new-segment.spec.tsx +++ b/web/app/components/datasets/documents/detail/__tests__/new-segment.spec.tsx @@ -1,26 +1,20 @@ -import type * as React from 'react' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { ChunkingMode } from '@/models/datasets' import { IndexingType } from '../../../create/step-two' import NewSegmentModal from '../new-segment' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: 'test-dataset-id', documentId: 'test-document-id', }), })) -const mockNotify = vi.fn() -vi.mock('use-context-selector', async (importOriginal) => { - const actual = await importOriginal() as Record - return { - ...actual, - useContext: () => ({ notify: mockNotify }), - } -}) +const toastErrorSpy = vi.spyOn(toast, 'error') +const toastSuccessSpy = vi.spyOn(toast, 'success') // Mock dataset detail context let mockIndexingTechnique = IndexingType.QUALIFIED @@ -51,11 +45,6 @@ vi.mock('@/service/knowledge/use-segment', () => ({ }), })) -// Mock app store -vi.mock('@/app/components/app/store', () => ({ - useStore: () => ({ appSidebarExpand: 'expand' }), -})) - vi.mock('../completed/common/action-buttons', () => ({ default: ({ handleCancel, handleSave, loading, actionType }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string }) => (
    @@ -139,6 +128,8 @@ vi.mock('@/app/components/datasets/common/image-uploader/image-uploader-in-chunk describe('NewSegmentModal', () => { beforeEach(() => { vi.clearAllMocks() + vi.useRealTimers() + toast.dismiss() mockFullScreen = false mockIndexingTechnique = IndexingType.QUALIFIED }) @@ -258,11 +249,7 @@ describe('NewSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(toastErrorSpy).toHaveBeenCalledTimes(1) }) }) @@ -272,11 +259,7 @@ describe('NewSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(toastErrorSpy).toHaveBeenCalledTimes(1) }) }) @@ -287,11 +270,7 @@ describe('NewSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(toastErrorSpy).toHaveBeenCalledTimes(1) }) }) }) @@ -337,11 +316,7 @@ describe('NewSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'success', - }), - ) + expect(toastSuccessSpy).toHaveBeenCalledTimes(1) }) }) }) @@ -430,10 +405,9 @@ describe('NewSegmentModal', () => { }) }) - describe('CustomButton in success notification', () => { - it('should call viewNewlyAddedChunk when custom button is clicked', async () => { + describe('Action button in success notification', () => { + it('should call viewNewlyAddedChunk when the toast action is clicked', async () => { const mockViewNewlyAddedChunk = vi.fn() - mockNotify.mockImplementation(() => {}) mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => { options.onSuccess() @@ -442,37 +416,25 @@ describe('NewSegmentModal', () => { }) render( - , + <> + + + , ) - // Enter content and save fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } }) fireEvent.click(screen.getByTestId('save-btn')) + const actionButton = await screen.findByRole('button', { name: 'common.operation.view' }) + fireEvent.click(actionButton) + await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'success', - customComponent: expect.anything(), - }), - ) + expect(mockViewNewlyAddedChunk).toHaveBeenCalledTimes(1) }) - - // Extract customComponent from the notify call args - const notifyCallArgs = mockNotify.mock.calls[0][0] as { customComponent?: React.ReactElement } - expect(notifyCallArgs.customComponent).toBeDefined() - const customComponent = notifyCallArgs.customComponent! - const { container: btnContainer } = render(customComponent) - const viewButton = btnContainer.querySelector('.system-xs-semibold.text-text-accent') as HTMLElement - expect(viewButton).toBeInTheDocument() - fireEvent.click(viewButton) - - // Assert that viewNewlyAddedChunk was called via the onClick handler (lines 66-67) - expect(mockViewNewlyAddedChunk).toHaveBeenCalled() }) }) @@ -599,9 +561,8 @@ describe('NewSegmentModal', () => { }) }) - describe('onSave delayed call', () => { - it('should call onSave after timeout in success handler', async () => { - vi.useFakeTimers() + describe('onSave after success', () => { + it('should call onSave immediately after save succeeds', async () => { const mockOnSave = vi.fn() mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => { options.onSuccess() @@ -611,15 +572,12 @@ describe('NewSegmentModal', () => { render() - // Enter content and save fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } }) fireEvent.click(screen.getByTestId('save-btn')) - // Fast-forward timer - vi.advanceTimersByTime(3000) - - expect(mockOnSave).toHaveBeenCalled() - vi.useRealTimers() + await waitFor(() => { + expect(mockOnSave).toHaveBeenCalledTimes(1) + }) }) }) diff --git a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx index 59ecbf5f25..2a68e6f627 100644 --- a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx @@ -49,7 +49,7 @@ const { mockOnDelete: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => '/datasets/test-dataset-id/documents/test-document-id', })) diff --git a/web/app/components/datasets/documents/detail/completed/__tests__/new-child-segment.spec.tsx b/web/app/components/datasets/documents/detail/completed/__tests__/new-child-segment.spec.tsx index 1b26a15b65..150d399a5d 100644 --- a/web/app/components/datasets/documents/detail/completed/__tests__/new-child-segment.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/__tests__/new-child-segment.spec.tsx @@ -1,23 +1,18 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import NewChildSegmentModal from '../new-child-segment' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({ datasetId: 'test-dataset-id', documentId: 'test-document-id', }), })) -const mockNotify = vi.fn() -vi.mock('use-context-selector', async (importOriginal) => { - const actual = await importOriginal() as Record - return { - ...actual, - useContext: () => ({ notify: mockNotify }), - } -}) +const toastErrorSpy = vi.spyOn(toast, 'error') +const toastSuccessSpy = vi.spyOn(toast, 'success') // Mock document context let mockParentMode = 'paragraph' @@ -48,11 +43,6 @@ vi.mock('@/service/knowledge/use-segment', () => ({ }), })) -// Mock app store -vi.mock('@/app/components/app/store', () => ({ - useStore: () => ({ appSidebarExpand: 'expand' }), -})) - vi.mock('../common/action-buttons', () => ({ default: ({ handleCancel, handleSave, loading, actionType, isChildChunk }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string, isChildChunk?: boolean }) => (
    @@ -103,6 +93,8 @@ vi.mock('../common/segment-index-tag', () => ({ describe('NewChildSegmentModal', () => { beforeEach(() => { vi.clearAllMocks() + vi.useRealTimers() + toast.dismiss() mockFullScreen = false mockParentMode = 'paragraph' }) @@ -198,11 +190,7 @@ describe('NewChildSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(toastErrorSpy).toHaveBeenCalledTimes(1) }) }) }) @@ -253,11 +241,7 @@ describe('NewChildSegmentModal', () => { fireEvent.click(screen.getByTestId('save-btn')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'success', - }), - ) + expect(toastSuccessSpy).toHaveBeenCalledTimes(1) }) }) }) @@ -374,35 +358,62 @@ describe('NewChildSegmentModal', () => { // View newly added chunk describe('View Newly Added Chunk', () => { - it('should show custom button in full-doc mode after save', async () => { + it('should call viewNewlyAddedChildChunk when the toast action is clicked', async () => { mockParentMode = 'full-doc' + const mockViewNewlyAddedChildChunk = vi.fn() mockAddChildSegment.mockImplementation((_params, options) => { options.onSuccess({ data: { id: 'new-child-id' } }) options.onSettled() return Promise.resolve() }) - render() + render( + <> + + + , + ) - // Enter valid content fireEvent.change(screen.getByTestId('content-input'), { target: { value: 'Valid content' }, }) fireEvent.click(screen.getByTestId('save-btn')) - // Assert - success notification with custom component + const actionButton = await screen.findByRole('button', { name: 'common.operation.view' }) + fireEvent.click(actionButton) + await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'success', - customComponent: expect.anything(), - }), - ) + expect(mockViewNewlyAddedChildChunk).toHaveBeenCalledTimes(1) }) }) - it('should not show custom button in paragraph mode after save', async () => { + it('should call onSave immediately in full-doc mode after save succeeds', async () => { + mockParentMode = 'full-doc' + const mockOnSave = vi.fn() + mockAddChildSegment.mockImplementation((_params, options) => { + options.onSuccess({ data: { id: 'new-child-id' } }) + options.onSettled() + return Promise.resolve() + }) + + render() + + fireEvent.change(screen.getByTestId('content-input'), { + target: { value: 'Valid content' }, + }) + + fireEvent.click(screen.getByTestId('save-btn')) + + await waitFor(() => { + expect(mockOnSave).toHaveBeenCalledTimes(1) + }) + }) + + it('should call onSave with the new child chunk in paragraph mode', async () => { mockParentMode = 'paragraph' const mockOnSave = vi.fn() mockAddChildSegment.mockImplementation((_params, options) => { diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts index f54c00e3e7..6e9239c972 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts @@ -68,7 +68,7 @@ const { mockPathname: { current: '/datasets/test/documents/test' }, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname.current, })) diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts index aa91e9f464..8948f6b547 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts @@ -1,12 +1,12 @@ import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types' import type { SegmentDetailModel, SegmentsResponse, SegmentUpdater } from '@/models/datasets' import { useQueryClient } from '@tanstack/react-query' -import { usePathname } from 'next/navigation' import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' import { useToastContext } from '@/app/components/base/toast/context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { ChunkingMode } from '@/models/datasets' +import { usePathname } from '@/next/navigation' import { useChunkListAllKey, useChunkListDisabledKey, diff --git a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx index e28fb774fb..2766754f7d 100644 --- a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx +++ b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx @@ -1,15 +1,12 @@ import type { FC } from 'react' import type { ChildChunkDetail, SegmentUpdater } from '@/models/datasets' import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react' -import { useParams } from 'next/navigation' -import { memo, useMemo, useRef, useState } from 'react' +import { memo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { useShallow } from 'zustand/react/shallow' -import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { ChunkingMode } from '@/models/datasets' +import { useParams } from '@/next/navigation' import { useAddChildSegment } from '@/service/knowledge/use-segment' import { cn } from '@/utils/classnames' import { formatNumber } from '@/utils/format' @@ -35,39 +32,15 @@ const NewChildSegmentModal: FC = ({ viewNewlyAddedChildChunk, }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [content, setContent] = useState('') const { datasetId, documentId } = useParams<{ datasetId: string, documentId: string }>() const [loading, setLoading] = useState(false) const [addAnother, setAddAnother] = useState(true) const fullScreen = useSegmentListContext(s => s.fullScreen) const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) - const { appSidebarExpand } = useAppStore(useShallow(state => ({ - appSidebarExpand: state.appSidebarExpand, - }))) const parentMode = useDocumentContext(s => s.parentMode) - const refreshTimer = useRef(null) - - const isFullDocMode = useMemo(() => { - return parentMode === 'full-doc' - }, [parentMode]) - - const CustomButton = ( - <> - - - - ) + const isFullDocMode = parentMode === 'full-doc' const handleCancel = (actionType: 'esc' | 'add' = 'esc') => { if (actionType === 'esc' || !addAnother) @@ -80,26 +53,25 @@ const NewChildSegmentModal: FC = ({ const params: SegmentUpdater = { content: '' } if (!content.trim()) - return notify({ type: 'error', message: t('segment.contentEmpty', { ns: 'datasetDocuments' }) }) + return toast.error(t('segment.contentEmpty', { ns: 'datasetDocuments' })) params.content = content setLoading(true) await addChildSegment({ datasetId, documentId, segmentId: chunkId, body: params }, { onSuccess(res) { - notify({ - type: 'success', - message: t('segment.childChunkAdded', { ns: 'datasetDocuments' }), - className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'} - !top-auto !right-auto !mb-[52px] !ml-11`, - customComponent: isFullDocMode && CustomButton, + toast.success(t('segment.childChunkAdded', { ns: 'datasetDocuments' }), { + actionProps: isFullDocMode + ? { + children: t('operation.view', { ns: 'common' }), + onClick: viewNewlyAddedChildChunk, + } + : undefined, }) handleCancel('add') setContent('') if (isFullDocMode) { - refreshTimer.current = setTimeout(() => { - onSave() - }, 3000) + onSave() } else { onSave(res.data) @@ -111,10 +83,8 @@ const NewChildSegmentModal: FC = ({ }) } - const wordCountText = useMemo(() => { - const count = content.length - return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}` - }, [content.length]) + const count = content.length + const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}` return (
    diff --git a/web/app/components/datasets/documents/detail/document-title.tsx b/web/app/components/datasets/documents/detail/document-title.tsx index ec44e3ea97..2190338ab2 100644 --- a/web/app/components/datasets/documents/detail/document-title.tsx +++ b/web/app/components/datasets/documents/detail/document-title.tsx @@ -1,6 +1,6 @@ import type { FC } from 'react' import type { ChunkingMode, ParentMode } from '@/models/datasets' -import { useRouter } from 'next/navigation' +import { useRouter } from '@/next/navigation' import { cn } from '@/utils/classnames' import DocumentPicker from '../../common/document-picker' diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx index b6842605c6..891c177169 100644 --- a/web/app/components/datasets/documents/detail/index.tsx +++ b/web/app/components/datasets/documents/detail/index.tsx @@ -1,7 +1,6 @@ 'use client' import type { FC } from 'react' import type { DataSourceInfo, FileItem, FullDocumentDetail, LegacyDataSourceInfo } from '@/models/datasets' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -13,6 +12,7 @@ import Metadata from '@/app/components/datasets/metadata/metadata-document' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { ChunkingMode } from '@/models/datasets' +import { useRouter, useSearchParams } from '@/next/navigation' import { useDocumentDetail, useDocumentMetadata, useInvalidDocumentList } from '@/service/knowledge/use-document' import { useCheckSegmentBatchImportProgress, useChildSegmentListKey, useSegmentBatchImport, useSegmentListKey } from '@/service/knowledge/use-segment' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/datasets/documents/detail/new-segment.tsx b/web/app/components/datasets/documents/detail/new-segment.tsx index d2e27e9969..bbc8a3b16b 100644 --- a/web/app/components/datasets/documents/detail/new-segment.tsx +++ b/web/app/components/datasets/documents/detail/new-segment.tsx @@ -2,17 +2,14 @@ import type { FC } from 'react' import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types' import type { SegmentUpdater } from '@/models/datasets' import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react' -import { useParams } from 'next/navigation' -import { memo, useCallback, useMemo, useRef, useState } from 'react' +import { memo, useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { useShallow } from 'zustand/react/shallow' -import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { ChunkingMode } from '@/models/datasets' +import { useParams } from '@/next/navigation' import { useAddSegment } from '@/service/knowledge/use-segment' import { cn } from '@/utils/classnames' import { formatNumber } from '@/utils/format' @@ -39,7 +36,6 @@ const NewSegmentModal: FC = ({ viewNewlyAddedChunk, }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [question, setQuestion] = useState('') const [answer, setAnswer] = useState('') const [attachments, setAttachments] = useState([]) @@ -50,27 +46,7 @@ const NewSegmentModal: FC = ({ const fullScreen = useSegmentListContext(s => s.fullScreen) const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique) - const { appSidebarExpand } = useAppStore(useShallow(state => ({ - appSidebarExpand: state.appSidebarExpand, - }))) - const [imageUploaderKey, setImageUploaderKey] = useState(Date.now()) - const refreshTimer = useRef(null) - - const CustomButton = useMemo(() => ( - <> - - - - ), [viewNewlyAddedChunk, t]) + const [imageUploaderKey, setImageUploaderKey] = useState(() => Date.now()) const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => { if (actionType === 'esc' || !addAnother) @@ -87,16 +63,10 @@ const NewSegmentModal: FC = ({ const params: SegmentUpdater = { content: '', attachment_ids: [] } if (docForm === ChunkingMode.qa) { if (!question.trim()) { - return notify({ - type: 'error', - message: t('segment.questionEmpty', { ns: 'datasetDocuments' }), - }) + return toast.error(t('segment.questionEmpty', { ns: 'datasetDocuments' })) } if (!answer.trim()) { - return notify({ - type: 'error', - message: t('segment.answerEmpty', { ns: 'datasetDocuments' }), - }) + return toast.error(t('segment.answerEmpty', { ns: 'datasetDocuments' })) } params.content = question @@ -104,10 +74,7 @@ const NewSegmentModal: FC = ({ } else { if (!question.trim()) { - return notify({ - type: 'error', - message: t('segment.contentEmpty', { ns: 'datasetDocuments' }), - }) + return toast.error(t('segment.contentEmpty', { ns: 'datasetDocuments' })) } params.content = question @@ -122,12 +89,11 @@ const NewSegmentModal: FC = ({ setLoading(true) await addSegment({ datasetId, documentId, body: params }, { onSuccess() { - notify({ - type: 'success', - message: t('segment.chunkAdded', { ns: 'datasetDocuments' }), - className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'} - !top-auto !right-auto !mb-[52px] !ml-11`, - customComponent: CustomButton, + toast.success(t('segment.chunkAdded', { ns: 'datasetDocuments' }), { + actionProps: { + children: t('operation.view', { ns: 'common' }), + onClick: viewNewlyAddedChunk, + }, }) handleCancel('add') setQuestion('') @@ -135,20 +101,16 @@ const NewSegmentModal: FC = ({ setAttachments([]) setImageUploaderKey(Date.now()) setKeywords([]) - refreshTimer.current = setTimeout(() => { - onSave() - }, 3000) + onSave() }, onSettled() { setLoading(false) }, }) - }, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave]) + }, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, t, handleCancel, onSave, viewNewlyAddedChunk]) - const wordCountText = useMemo(() => { - const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length - return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}` - }, [question.length, answer.length, docForm, t]) + const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length + const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}` const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL diff --git a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx index 84534298c9..bf516d432b 100644 --- a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx +++ b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx @@ -5,7 +5,7 @@ import DocumentSettings from '../document-settings' const mockPush = vi.fn() const mockBack = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, back: mockBack, @@ -224,6 +224,20 @@ describe('DocumentSettings', () => { // Data source types describe('Data Source Types', () => { + it('should handle upload_file_id data source format', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: '4a807f05-45d6-4fc4-b7a8-b009a4568b36', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + it('should handle legacy upload_file data source', () => { mockDocumentDetail = { name: 'test-document', @@ -307,6 +321,18 @@ describe('DocumentSettings', () => { expect(screen.getByTestId('files-count')).toHaveTextContent('0') }) + it('should handle empty data_source_info object', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: {}, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('0') + }) + it('should maintain structure when rerendered', () => { const { rerender } = render( , @@ -317,4 +343,37 @@ describe('DocumentSettings', () => { expect(screen.getByTestId('step-two')).toBeInTheDocument() }) }) + + describe('Files Extraction Regression Tests', () => { + it('should correctly extract file ID from upload_file_id format', () => { + const fileId = '4a807f05-45d6-4fc4-b7a8-b009a4568b36' + mockDocumentDetail = { + name: 'test-document.pdf', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: fileId, + }, + } + + render() + + // Verify files array is populated with correct file ID + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + + it('should preserve document name when using upload_file_id format', () => { + const documentName = 'my-uploaded-document.txt' + mockDocumentDetail = { + name: documentName, + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: 'some-file-id', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + }) }) diff --git a/web/app/components/datasets/documents/detail/settings/document-settings.tsx b/web/app/components/datasets/documents/detail/settings/document-settings.tsx index 67773cb7d6..2b6cc77683 100644 --- a/web/app/components/datasets/documents/detail/settings/document-settings.tsx +++ b/web/app/components/datasets/documents/detail/settings/document-settings.tsx @@ -8,10 +8,10 @@ import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, + UploadFileIdInfo, WebsiteCrawlInfo, } from '@/models/datasets' import { useBoolean } from 'ahooks' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' @@ -24,6 +24,7 @@ import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/con import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import DatasetDetailContext from '@/context/dataset-detail' +import { useRouter } from '@/next/navigation' import { useDocumentDetail, useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document' type DocumentSettingsProps = { @@ -61,6 +62,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const dataSourceInfo = documentDetail?.data_source_info + // Type guards for DataSourceInfo union const isLegacyDataSourceInfo = (info: DataSourceInfo | undefined): info is LegacyDataSourceInfo => { return !!info && 'upload_file' in info } @@ -73,10 +75,15 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const isLocalFileInfo = (info: DataSourceInfo | undefined): info is LocalFileInfo => { return !!info && 'related_id' in info && 'transfer_method' in info } + const isUploadFileIdInfo = (info: DataSourceInfo | undefined): info is UploadFileIdInfo => { + return !!info && 'upload_file_id' in info + } + const legacyInfo = isLegacyDataSourceInfo(dataSourceInfo) ? dataSourceInfo : undefined const websiteInfo = isWebsiteCrawlInfo(dataSourceInfo) ? dataSourceInfo : undefined const onlineDocumentInfo = isOnlineDocumentInfo(dataSourceInfo) ? dataSourceInfo : undefined const localFileInfo = isLocalFileInfo(dataSourceInfo) ? dataSourceInfo : undefined + const uploadFileIdInfo = isUploadFileIdInfo(dataSourceInfo) ? dataSourceInfo : undefined const currentPage = useMemo(() => { if (legacyInfo) { @@ -101,8 +108,20 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { }, [documentDetail?.data_source_type, documentDetail?.name, legacyInfo, onlineDocumentInfo]) const files = useMemo(() => { - if (legacyInfo?.upload_file) - return [legacyInfo.upload_file as CustomFile] + // Handle upload_file_id format + if (uploadFileIdInfo) { + return [{ + id: uploadFileIdInfo.upload_file_id, + name: documentDetail?.name || '', + } as unknown as CustomFile] + } + + // Handle legacy upload_file format + if (legacyInfo?.upload_file) { + return [legacyInfo.upload_file as unknown as CustomFile] + } + + // Handle local file info format if (localFileInfo) { const { related_id, name, extension } = localFileInfo return [{ @@ -111,8 +130,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { extension, } as unknown as CustomFile] } + return [] - }, [legacyInfo?.upload_file, localFileInfo]) + }, [uploadFileIdInfo, legacyInfo?.upload_file, localFileInfo, documentDetail?.name]) const websitePages = useMemo(() => { if (!websiteInfo) diff --git a/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/index.spec.tsx index 9f2ccc0acd..764667c55c 100644 --- a/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ import PipelineSettings from '../index' // Mock Next.js router const mockPush = vi.fn() const mockBack = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, back: mockBack, diff --git a/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/left-header.spec.tsx b/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/left-header.spec.tsx index 9a1ffab673..30019ca67d 100644 --- a/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/left-header.spec.tsx +++ b/web/app/components/datasets/documents/detail/settings/pipeline-settings/__tests__/left-header.spec.tsx @@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import LeftHeader from '../left-header' const mockBack = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ back: mockBack, }), diff --git a/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx b/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx index 08e13765e5..4c9dd641e3 100644 --- a/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx +++ b/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx @@ -2,13 +2,13 @@ import type { NotionPage } from '@/models/common' import type { CrawlResultItem, CustomFile, FileIndexingEstimateResponse } from '@/models/datasets' import type { OnlineDriveFile, PublishedPipelineRunPreviewResponse } from '@/models/pipeline' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { DatasourceType } from '@/models/pipeline' +import { useRouter } from '@/next/navigation' import { useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document' import { usePipelineExecutionLog, useRunPublishedPipeline } from '@/service/use-pipeline' import ChunkPreview from '../../../create-from-pipeline/preview/chunk-preview' diff --git a/web/app/components/datasets/documents/detail/settings/pipeline-settings/left-header.tsx b/web/app/components/datasets/documents/detail/settings/pipeline-settings/left-header.tsx index 280d835586..15b06a5f10 100644 --- a/web/app/components/datasets/documents/detail/settings/pipeline-settings/left-header.tsx +++ b/web/app/components/datasets/documents/detail/settings/pipeline-settings/left-header.tsx @@ -1,10 +1,10 @@ import { RiArrowLeftLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Effect from '@/app/components/base/effect' +import { useRouter } from '@/next/navigation' type LeftHeaderProps = { title: string diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index 764b04227c..29d9c01f71 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' -import { useRouter } from 'next/navigation' import { useCallback } from 'react' import Loading from '@/app/components/base/loading' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import { DataSourceType } from '@/models/datasets' +import { useRouter } from '@/next/navigation' import { useDocumentList, useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document' import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/use-segment' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx index a6a60aa856..0949648fa0 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ import ExternalKnowledgeBaseConnector from '../index' const mockRouterBack = vi.fn() const mockReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ back: mockRouterBack, replace: mockReplace, @@ -21,12 +21,19 @@ vi.mock('@/context/i18n', () => ({ useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`, })) -const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ - notify: mockNotify, - }), -})) +const mockToastSuccess = vi.hoisted(() => vi.fn()) +const mockToastError = vi.hoisted(() => vi.fn()) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + error: mockToastError, + }, + } +}) // Mock modal context vi.mock('@/context/modal-context', () => ({ @@ -162,10 +169,7 @@ describe('ExternalKnowledgeBaseConnector', () => { }) // Verify success notification - expect(mockNotify).toHaveBeenCalledWith({ - type: 'success', - message: 'External Knowledge Base Connected Successfully', - }) + expect(mockToastSuccess).toHaveBeenCalledWith('dataset.externalKnowledgeForm.connectedSuccess') // Verify navigation back expect(mockRouterBack).toHaveBeenCalledTimes(1) @@ -204,10 +208,7 @@ describe('ExternalKnowledgeBaseConnector', () => { // Verify error notification await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'Failed to connect External Knowledge Base', - }) + expect(mockToastError).toHaveBeenCalledWith('dataset.externalKnowledgeForm.connectedFailed') }) // Verify no navigation @@ -226,10 +227,7 @@ describe('ExternalKnowledgeBaseConnector', () => { await fillFormAndSubmit(user) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'Failed to connect External Knowledge Base', - }) + expect(mockToastError).toHaveBeenCalledWith('dataset.externalKnowledgeForm.connectedFailed') }) expect(mockRouterBack).not.toHaveBeenCalled() @@ -272,10 +270,7 @@ describe('ExternalKnowledgeBaseConnector', () => { resolvePromise({ id: 'new-id' }) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith({ - type: 'success', - message: 'External Knowledge Base Connected Successfully', - }) + expect(mockToastSuccess).toHaveBeenCalledWith('dataset.externalKnowledgeForm.connectedSuccess') }) }) }) diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx index cf36eed382..85fc254cfc 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -1,25 +1,26 @@ 'use client' import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-knowledge-base/create/declarations' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' +import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' +import { useRouter } from '@/next/navigation' import { createExternalKnowledgeBase } from '@/service/datasets' const ExternalKnowledgeBaseConnector = () => { - const { notify } = useToastContext() const [loading, setLoading] = useState(false) const router = useRouter() + const { t } = useTranslation() const handleConnect = async (formValue: CreateKnowledgeBaseReq) => { try { setLoading(true) const result = await createExternalKnowledgeBase({ body: formValue }) if (result && result.id) { - notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' }) + toast.success(t('externalKnowledgeForm.connectedSuccess', { ns: 'dataset' })) trackEvent('create_external_knowledge_base', { provider: formValue.provider, name: formValue.name, @@ -30,7 +31,7 @@ const ExternalKnowledgeBaseConnector = () => { } catch (error) { console.error('Error creating external knowledge base:', error) - notify({ type: 'error', message: 'Failed to connect External Knowledge Base' }) + toast.error(t('externalKnowledgeForm.connectedFailed', { ns: 'dataset' })) } setLoading(false) } diff --git a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx index f84e6c57c1..a527da982a 100644 --- a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx @@ -2,13 +2,13 @@ import { RiAddLine, RiArrowDownSLine, } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' import { useModalContext } from '@/context/modal-context' +import { useRouter } from '@/next/navigation' type ApiItem = { value: string diff --git a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx index 75b9e8de9c..4652a8a5f1 100644 --- a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx @@ -1,7 +1,6 @@ 'use client' import { RiAddLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -9,6 +8,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' import { useModalContext } from '@/context/modal-context' +import { useRouter } from '@/next/navigation' import ExternalApiSelect from './ExternalApiSelect' type ExternalApiSelectionProps = { diff --git a/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelect.spec.tsx b/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelect.spec.tsx index 3b8b35a5b7..7af75fbcdd 100644 --- a/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelect.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelect.spec.tsx @@ -12,7 +12,7 @@ const mocks = vi.hoisted(() => ({ mutateExternalKnowledgeApis: vi.fn(), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mocks.push, refresh: mocks.refresh }), })) diff --git a/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelection.spec.tsx b/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelection.spec.tsx index 702890bee9..8d055606b8 100644 --- a/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelection.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/__tests__/ExternalApiSelection.spec.tsx @@ -10,7 +10,7 @@ const mocks = vi.hoisted(() => ({ externalKnowledgeApiList: [] as Array<{ id: string, name: string, settings: { endpoint: string } }>, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mocks.push, refresh: mocks.refresh }), })) @@ -35,7 +35,7 @@ vi.mock('../ExternalApiSelect', () => ({ {value} {items.length} {items.map((item: MockSelectItem) => ( - ))} diff --git a/web/app/components/datasets/external-knowledge-base/create/__tests__/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/create/__tests__/index.spec.tsx index 213fe30ee3..a3282e441c 100644 --- a/web/app/components/datasets/external-knowledge-base/create/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/__tests__/index.spec.tsx @@ -7,7 +7,7 @@ import RetrievalSettings from '../RetrievalSettings' const mockReplace = vi.fn() const mockRefresh = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, push: vi.fn(), diff --git a/web/app/components/datasets/external-knowledge-base/create/index.tsx b/web/app/components/datasets/external-knowledge-base/create/index.tsx index 07b6e71fa6..0e855259ba 100644 --- a/web/app/components/datasets/external-knowledge-base/create/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/index.tsx @@ -2,12 +2,12 @@ import type { CreateKnowledgeBaseReq } from './declarations' import { RiArrowLeftLine, RiArrowRightLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useDocLink } from '@/context/i18n' +import { useRouter } from '@/next/navigation' import ExternalApiSelection from './ExternalApiSelection' import InfoPanel from './InfoPanel' import KnowledgeBaseInfo from './KnowledgeBaseInfo' diff --git a/web/app/components/datasets/extra-info/__tests__/index.spec.tsx b/web/app/components/datasets/extra-info/__tests__/index.spec.tsx index 4a8d89e9fb..de61894a11 100644 --- a/web/app/components/datasets/extra-info/__tests__/index.spec.tsx +++ b/web/app/components/datasets/extra-info/__tests__/index.spec.tsx @@ -13,7 +13,7 @@ import Statistics from '../statistics' // Mock Setup -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), replace: vi.fn(), @@ -23,7 +23,7 @@ vi.mock('next/navigation', () => ({ })) // Mock next/link -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, ...props }: { children: React.ReactNode, href: string, [key: string]: unknown }) => ( {children} ), diff --git a/web/app/components/datasets/extra-info/api-access/card.tsx b/web/app/components/datasets/extra-info/api-access/card.tsx index 946536bf2c..eee586ff8e 100644 --- a/web/app/components/datasets/extra-info/api-access/card.tsx +++ b/web/app/components/datasets/extra-info/api-access/card.tsx @@ -1,5 +1,4 @@ import { RiArrowRightUpLine, RiBookOpenLine } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -8,6 +7,7 @@ import Indicator from '@/app/components/header/indicator' import { useSelector as useAppContextSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' +import Link from '@/next/link' import { useDisableDatasetServiceApi, useEnableDatasetServiceApi } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/components/datasets/extra-info/service-api/__tests__/index.spec.tsx b/web/app/components/datasets/extra-info/service-api/__tests__/index.spec.tsx index b94508de6a..8137052383 100644 --- a/web/app/components/datasets/extra-info/service-api/__tests__/index.spec.tsx +++ b/web/app/components/datasets/extra-info/service-api/__tests__/index.spec.tsx @@ -9,7 +9,7 @@ import ServiceApi from '../index' // Mock Setup -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), replace: vi.fn(), @@ -19,7 +19,7 @@ vi.mock('next/navigation', () => ({ })) // Mock next/link -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href, ...props }: { children: React.ReactNode, href: string, [key: string]: unknown }) => ( {children} ), diff --git a/web/app/components/datasets/extra-info/service-api/card.tsx b/web/app/components/datasets/extra-info/service-api/card.tsx index 31076d12fc..bf84204ea4 100644 --- a/web/app/components/datasets/extra-info/service-api/card.tsx +++ b/web/app/components/datasets/extra-info/service-api/card.tsx @@ -1,5 +1,4 @@ import { RiBookOpenLine, RiKey2Line } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -9,6 +8,7 @@ import { ApiAggregate } from '@/app/components/base/icons/src/vender/knowledge' import SecretKeyModal from '@/app/components/develop/secret-key/secret-key-modal' import Indicator from '@/app/components/header/indicator' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' +import Link from '@/next/link' type CardProps = { apiBaseUrl: string diff --git a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx index fe7510b498..2dda6ecaae 100644 --- a/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/__tests__/index.spec.tsx @@ -27,7 +27,7 @@ vi.mock('@/app/components/datasets/external-knowledge-base/create/RetrievalSetti // Mock Setup -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), replace: vi.fn(), diff --git a/web/app/components/datasets/list/__tests__/datasets.spec.tsx b/web/app/components/datasets/list/__tests__/datasets.spec.tsx index 49bda88c8b..5b777e0b2e 100644 --- a/web/app/components/datasets/list/__tests__/datasets.spec.tsx +++ b/web/app/components/datasets/list/__tests__/datasets.spec.tsx @@ -6,7 +6,7 @@ import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datase import { RETRIEVE_METHOD } from '@/types/app' import Datasets from '../datasets' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), })) diff --git a/web/app/components/datasets/list/__tests__/index.spec.tsx b/web/app/components/datasets/list/__tests__/index.spec.tsx index 73e0ba0960..37a787ff51 100644 --- a/web/app/components/datasets/list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/list/__tests__/index.spec.tsx @@ -4,7 +4,7 @@ import List from '../index' const mockPush = vi.fn() const mockReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, replace: mockReplace, diff --git a/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx b/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx index ebe80e4686..21ddda5ce6 100644 --- a/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx +++ b/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx @@ -22,7 +22,7 @@ vi.mock('@/hooks/use-format-time-from-now', () => ({ const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush }), })) diff --git a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts index 63ac30630e..f29d85b460 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts @@ -5,9 +5,15 @@ import { IndexingType } from '@/app/components/datasets/create/step-two' import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' import { useDatasetCardState } from '../use-dataset-card-state' -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, }, })) @@ -299,7 +305,7 @@ describe('useDatasetCardState', () => { describe('Error Handling', () => { it('should show error toast when export pipeline fails', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockExportPipeline.mockRejectedValue(new Error('Export failed')) const dataset = createMockDataset({ pipeline_id: 'pipeline-1' }) @@ -311,14 +317,11 @@ describe('useDatasetCardState', () => { await result.current.handleExportPipeline() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should handle Response error in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const mockResponse = new Response(JSON.stringify({ message: 'API Error' }), { status: 400, }) @@ -333,14 +336,11 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.stringContaining('API Error'), - }) + expect(toast.error).toHaveBeenCalledWith(expect.stringContaining('API Error')) }) it('should handle generic Error in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockCheckUsage.mockRejectedValue(new Error('Network error')) const dataset = createMockDataset() @@ -352,14 +352,11 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Network error', - }) + expect(toast.error).toHaveBeenCalledWith('Network error') }) it('should handle error without message in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockCheckUsage.mockRejectedValue({}) const dataset = createMockDataset() @@ -371,10 +368,7 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Unknown error', - }) + expect(toast.error).toHaveBeenCalledWith('dataset.unknownError') }) it('should handle exporting state correctly', async () => { diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts index 4bd8357f1c..850eee4364 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -2,7 +2,7 @@ import type { Tag } from '@/app/components/base/tag-management/constant' import type { DataSet } from '@/models/datasets' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' import { useExportPipelineDSL } from '@/service/use-pipeline' import { downloadBlob } from '@/utils/download' @@ -70,7 +70,7 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { - Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast.error(t('exportFailed', { ns: 'app' })) } finally { setExporting(false) @@ -93,10 +93,10 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO catch (e: unknown) { if (e instanceof Response) { const res = await e.json() - Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) + toast.error(res?.message || t('unknownError', { ns: 'dataset' })) } else { - Toast.notify({ type: 'error', message: (e as Error)?.message || 'Unknown error' }) + toast.error((e as Error)?.message || t('unknownError', { ns: 'dataset' })) } } }, [dataset.id, checkUsage, t]) @@ -104,7 +104,7 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO const onConfirmDelete = useCallback(async () => { try { await deleteDatasetMutation(dataset.id) - Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) + toast.success(t('datasetDeleted', { ns: 'dataset' })) onSuccess?.() } finally { diff --git a/web/app/components/datasets/list/dataset-card/index.tsx b/web/app/components/datasets/list/dataset-card/index.tsx index 85dba7e8ff..2a22255eda 100644 --- a/web/app/components/datasets/list/dataset-card/index.tsx +++ b/web/app/components/datasets/list/dataset-card/index.tsx @@ -1,9 +1,9 @@ 'use client' import type { DataSet } from '@/models/datasets' import { useHover } from 'ahooks' -import { useRouter } from 'next/navigation' import { useMemo, useRef } from 'react' import { useSelector as useAppContextWithSelector } from '@/context/app-context' +import { useRouter } from '@/next/navigation' import CornerLabels from './components/corner-labels' import DatasetCardFooter from './components/dataset-card-footer' import DatasetCardHeader from './components/dataset-card-header' diff --git a/web/app/components/datasets/list/new-dataset-card/option.tsx b/web/app/components/datasets/list/new-dataset-card/option.tsx index e862b5c11e..05b14fef1a 100644 --- a/web/app/components/datasets/list/new-dataset-card/option.tsx +++ b/web/app/components/datasets/list/new-dataset-card/option.tsx @@ -1,5 +1,5 @@ -import Link from 'next/link' import * as React from 'react' +import Link from '@/next/link' type OptionProps = { Icon: React.ComponentType<{ className?: string }> diff --git a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx index e56fe46422..9cc4f89bd8 100644 --- a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx +++ b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx @@ -45,7 +45,7 @@ vi.mock('../../hooks/use-check-metadata-name', () => ({ }), })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), }), diff --git a/web/app/components/datasets/metadata/metadata-document/__tests__/info-group.spec.tsx b/web/app/components/datasets/metadata/metadata-document/__tests__/info-group.spec.tsx index f30e188cd7..d783b882a8 100644 --- a/web/app/components/datasets/metadata/metadata-document/__tests__/info-group.spec.tsx +++ b/web/app/components/datasets/metadata/metadata-document/__tests__/info-group.spec.tsx @@ -22,7 +22,7 @@ type InputCombinedProps = { type: DataType } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn(), }), diff --git a/web/app/components/datasets/metadata/metadata-document/info-group.tsx b/web/app/components/datasets/metadata/metadata-document/info-group.tsx index 6d172c92f4..0b21d607bd 100644 --- a/web/app/components/datasets/metadata/metadata-document/info-group.tsx +++ b/web/app/components/datasets/metadata/metadata-document/info-group.tsx @@ -2,12 +2,12 @@ import type { FC } from 'react' import type { MetadataItemWithValue } from '../types' import { RiDeleteBinLine, RiQuestionLine } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import Tooltip from '@/app/components/base/tooltip' import useTimestamp from '@/hooks/use-timestamp' +import { useRouter } from '@/next/navigation' import { cn } from '@/utils/classnames' import AddMetadataButton from '../add-metadata-button' import InputCombined from '../edit-metadata-batch/input-combined' diff --git a/web/app/components/datasets/settings/form/__tests__/index.spec.tsx b/web/app/components/datasets/settings/form/__tests__/index.spec.tsx index 9eeb62b8e8..a3a22e000b 100644 --- a/web/app/components/datasets/settings/form/__tests__/index.spec.tsx +++ b/web/app/components/datasets/settings/form/__tests__/index.spec.tsx @@ -6,6 +6,10 @@ import { RETRIEVE_METHOD } from '@/types/app' import { IndexingType } from '../../../create/step-two' import Form from '../index' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + // Mock contexts const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() @@ -189,9 +193,10 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), }, })) @@ -391,7 +396,7 @@ describe('Form', () => { }) it('should show error when trying to save with empty name', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') render(
    ) // Clear the name @@ -403,10 +408,7 @@ describe('Form', () => { fireEvent.click(saveButton) await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) diff --git a/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts b/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts index f27b542b1e..00462619aa 100644 --- a/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts +++ b/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts @@ -6,6 +6,11 @@ import { RETRIEVE_METHOD } from '@/types/app' import { IndexingType } from '../../../../create/step-two' import { useFormState } from '../use-form-state' +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + // Mock contexts const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() @@ -122,9 +127,10 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, }, })) @@ -423,7 +429,7 @@ describe('useFormState', () => { describe('handleSave', () => { it('should show error toast when name is empty', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -434,14 +440,11 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should show error toast when name is whitespace only', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -452,10 +455,7 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should call updateDatasetSetting with correct params', async () => { @@ -477,7 +477,7 @@ describe('useFormState', () => { }) it('should show success toast on successful save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) await act(async () => { @@ -485,10 +485,7 @@ describe('useFormState', () => { }) await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'success', - message: expect.any(String), - }) + expect(toast.success).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -553,7 +550,7 @@ describe('useFormState', () => { it('should show error toast on save failure', async () => { const { updateDatasetSetting } = await import('@/service/datasets') - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') vi.mocked(updateDatasetSetting).mockRejectedValueOnce(new Error('Network error')) const { result } = renderHook(() => useFormState()) @@ -562,10 +559,7 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should include partial_member_list when permission is partialMembers', async () => { diff --git a/web/app/components/datasets/settings/form/hooks/use-form-state.ts b/web/app/components/datasets/settings/form/hooks/use-form-state.ts index 614995d43a..d00534f7f4 100644 --- a/web/app/components/datasets/settings/form/hooks/use-form-state.ts +++ b/web/app/components/datasets/settings/form/hooks/use-form-state.ts @@ -6,7 +6,7 @@ import type { IconInfo, SummaryIndexSetting as SummaryIndexSettingType } from '@ import type { RetrievalConfig } from '@/types/app' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -123,12 +123,12 @@ export const useFormState = () => { return if (!name?.trim()) { - Toast.notify({ type: 'error', message: t('form.nameError', { ns: 'datasetSettings' }) }) + toast.error(t('form.nameError', { ns: 'datasetSettings' })) return } if (!isReRankModelSelected({ rerankModelList, retrievalConfig, indexMethod })) { - Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) }) + toast.error(t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })) return } @@ -176,7 +176,7 @@ export const useFormState = () => { } await updateDatasetSetting({ datasetId: currentDataset!.id, body }) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) if (mutateDatasets) { await mutateDatasets() @@ -184,7 +184,7 @@ export const useFormState = () => { } } catch { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) } finally { setLoading(false) diff --git a/web/app/components/devtools/agentation-loader.tsx b/web/app/components/devtools/agentation-loader.tsx new file mode 100644 index 0000000000..87e1b44c87 --- /dev/null +++ b/web/app/components/devtools/agentation-loader.tsx @@ -0,0 +1,13 @@ +'use client' + +import { IS_DEV } from '@/config' +import dynamic from '@/next/dynamic' + +const Agentation = dynamic(() => import('agentation').then(module => module.Agentation), { ssr: false }) + +export function AgentationLoader() { + if (!IS_DEV) + return null + + return +} diff --git a/web/app/components/devtools/react-grab/loader.tsx b/web/app/components/devtools/react-grab/loader.tsx index 3a1ecc6be8..4ee9ad1236 100644 --- a/web/app/components/devtools/react-grab/loader.tsx +++ b/web/app/components/devtools/react-grab/loader.tsx @@ -1,5 +1,5 @@ -import Script from 'next/script' import { IS_DEV } from '@/config' +import Script from '@/next/script' export function ReactGrabLoader() { if (!IS_DEV) diff --git a/web/app/components/devtools/react-scan/loader.tsx b/web/app/components/devtools/react-scan/loader.tsx index a5956d7825..8e933c2b24 100644 --- a/web/app/components/devtools/react-scan/loader.tsx +++ b/web/app/components/devtools/react-scan/loader.tsx @@ -1,5 +1,5 @@ -import Script from 'next/script' import { IS_DEV } from '@/config' +import Script from '@/next/script' export function ReactScanLoader() { if (!IS_DEV) diff --git a/web/app/components/explore/__tests__/index.spec.tsx b/web/app/components/explore/__tests__/index.spec.tsx index cf76593613..5c743928e8 100644 --- a/web/app/components/explore/__tests__/index.spec.tsx +++ b/web/app/components/explore/__tests__/index.spec.tsx @@ -8,7 +8,7 @@ const mockReplace = vi.fn() const mockPush = vi.fn() const mockInstalledAppsData = { installed_apps: [] as const } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, push: mockPush, diff --git a/web/app/components/explore/banner/__tests__/banner-item.spec.tsx b/web/app/components/explore/banner/__tests__/banner-item.spec.tsx index de35814e8e..2d07cbddd8 100644 --- a/web/app/components/explore/banner/__tests__/banner-item.spec.tsx +++ b/web/app/components/explore/banner/__tests__/banner-item.spec.tsx @@ -1,3 +1,4 @@ +import type { ComponentProps } from 'react' import type { Banner } from '@/models/app' import { cleanup, fireEvent, render, screen } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' @@ -5,6 +6,11 @@ import { BannerItem } from '../banner-item' const mockScrollTo = vi.fn() const mockSlideNodes = vi.fn() +const mockTrackEvent = vi.fn() + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) vi.mock('@/app/components/base/carousel', () => ({ useCarousel: () => ({ @@ -48,19 +54,34 @@ class MockResizeObserver { } } +const renderBannerItem = ( + banner: Banner = createMockBanner(), + props: Partial> = {}, +) => { + return render( + , + ) +} + describe('BannerItem', () => { let mockWindowOpen: ReturnType beforeEach(() => { mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) - mockSlideNodes.mockReturnValue([{}, {}, {}]) // 3 slides + mockSlideNodes.mockReturnValue([{}, {}, {}]) vi.stubGlobal('ResizeObserver', MockResizeObserver) Object.defineProperty(window, 'innerWidth', { writable: true, configurable: true, - value: 1400, // Above RESPONSIVE_BREAKPOINT (1200) + value: 1400, }) }) @@ -73,81 +94,51 @@ describe('BannerItem', () => { describe('basic rendering', () => { it('renders banner category', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Featured')).toBeInTheDocument() }) it('renders banner title', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) it('renders banner description', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test banner description text')).toBeInTheDocument() }) it('renders banner image with correct src and alt', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() const image = screen.getByRole('img') expect(image).toHaveAttribute('src', 'https://example.com/image.png') expect(image).toHaveAttribute('alt', 'Test Banner Title') }) it('renders view more text', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('explore.banner.viewMore')).toBeInTheDocument() }) }) describe('click handling', () => { - it('opens banner link in new tab when clicked', () => { + it('opens banner link in new tab and tracks click when clicked', () => { const banner = createMockBanner({ link: 'https://test-link.com' }) - render( - , - ) + renderBannerItem(banner, { sort: 2, language: 'zh-Hans', accountId: 'account-123' }) const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') fireEvent.click(bannerElement!) + expect(mockTrackEvent).toHaveBeenCalledWith('explore_banner_click', expect.objectContaining({ + banner_id: 'banner-1', + title: 'Test Banner Title', + sort: 2, + link: 'https://test-link.com', + page: 'explore', + language: 'zh-Hans', + account_id: 'account-123', + event_time: expect.any(Number), + })) expect(mockWindowOpen).toHaveBeenCalledWith( 'https://test-link.com', '_blank', @@ -155,18 +146,16 @@ describe('BannerItem', () => { ) }) - it('does not open window when banner has no link', () => { + it('tracks click even when banner has no link', () => { const banner = createMockBanner({ link: '' }) - render( - , - ) + renderBannerItem(banner) const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') fireEvent.click(bannerElement!) + expect(mockTrackEvent).toHaveBeenCalledWith('explore_banner_click', expect.objectContaining({ + link: '', + })) expect(mockWindowOpen).not.toHaveBeenCalled() }) }) @@ -174,28 +163,13 @@ describe('BannerItem', () => { describe('slide indicators', () => { it('renders correct number of indicator buttons', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(3) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(3) }) it('renders indicator buttons with correct numbers', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('01')).toBeInTheDocument() expect(screen.getByText('02')).toBeInTheDocument() expect(screen.getByText('03')).toBeInTheDocument() @@ -203,13 +177,7 @@ describe('BannerItem', () => { it('calls scrollTo when indicator is clicked', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) + renderBannerItem() const secondIndicator = screen.getByText('02').closest('button') fireEvent.click(secondIndicator!) @@ -219,81 +187,39 @@ describe('BannerItem', () => { it('renders no indicators when no slides', () => { mockSlideNodes.mockReturnValue([]) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.queryByRole('button')).not.toBeInTheDocument() }) }) describe('isPaused prop', () => { it('defaults isPaused to false', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) it('accepts isPaused prop', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem(createMockBanner(), { isPaused: true }) expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) }) describe('responsive behavior', () => { it('sets up ResizeObserver on mount', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(mockResizeObserverObserve).toHaveBeenCalled() }) it('adds resize event listener on mount', () => { const addEventListenerSpy = vi.spyOn(window, 'addEventListener') - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(addEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) addEventListenerSpy.mockRestore() }) it('removes resize event listener on unmount', () => { const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') - const banner = createMockBanner() - const { unmount } = render( - , - ) + const { unmount } = renderBannerItem() unmount() @@ -308,14 +234,7 @@ describe('BannerItem', () => { value: 1000, }) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) @@ -326,14 +245,7 @@ describe('BannerItem', () => { value: 800, }) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('explore.banner.viewMore')).toBeInTheDocument() }) }) @@ -348,13 +260,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) expect(screen.getByText('Very Long Category Name')).toBeInTheDocument() }) @@ -367,13 +274,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) const titleElement = screen.getByText('A Very Long Title That Should Be Truncated Eventually') expect(titleElement).toHaveClass('line-clamp-2') }) @@ -387,13 +289,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) const descriptionElement = screen.getByText(/A very long description/) expect(descriptionElement).toHaveClass('line-clamp-4') }) @@ -402,56 +299,26 @@ describe('BannerItem', () => { describe('slide calculation', () => { it('calculates next index correctly for first slide', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(3) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(3) }) it('handles single slide case', () => { mockSlideNodes.mockReturnValue([{}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(1) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(1) }) }) describe('wrapper styling', () => { it('has cursor-pointer class', () => { - const banner = createMockBanner() - const { container } = render( - , - ) - + const { container } = renderBannerItem() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('cursor-pointer') }) it('has rounded-2xl class', () => { - const banner = createMockBanner() - const { container } = render( - , - ) - + const { container } = renderBannerItem() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('rounded-2xl') }) diff --git a/web/app/components/explore/banner/__tests__/banner.spec.tsx b/web/app/components/explore/banner/__tests__/banner.spec.tsx index d6d0aa44a8..069aaf02dc 100644 --- a/web/app/components/explore/banner/__tests__/banner.spec.tsx +++ b/web/app/components/explore/banner/__tests__/banner.spec.tsx @@ -6,6 +6,8 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import Banner from '../banner' const mockUseGetBanners = vi.fn() +const mockUseSelector = vi.fn() +const mockTrackEvent = vi.fn() vi.mock('@/service/use-explore', () => ({ useGetBanners: (...args: unknown[]) => mockUseGetBanners(...args), @@ -15,6 +17,14 @@ vi.mock('@/context/i18n', () => ({ useLocale: () => 'en-US', })) +vi.mock('@/context/app-context', () => ({ + useSelector: (...args: unknown[]) => mockUseSelector(...args), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + vi.mock('@/app/components/base/carousel', () => ({ Carousel: Object.assign( ({ children, onMouseEnter, onMouseLeave, className }: { @@ -54,9 +64,12 @@ vi.mock('@/app/components/base/carousel', () => ({ })) vi.mock('../banner-item', () => ({ - BannerItem: ({ banner, autoplayDelay, isPaused }: { + BannerItem: ({ banner, autoplayDelay, isPaused, sort, language, accountId }: { banner: BannerType autoplayDelay: number + sort: number + language: string + accountId?: string isPaused?: boolean }) => (
    ({ data-banner-id={banner.id} data-autoplay-delay={autoplayDelay} data-is-paused={isPaused} + data-sort={sort} + data-language={language} + data-account-id={accountId} > BannerItem: {' '} @@ -87,6 +103,11 @@ const createMockBanner = (id: string, status: string = 'enabled', title: string describe('Banner', () => { beforeEach(() => { vi.useFakeTimers() + mockUseSelector.mockImplementation(selector => selector({ + userProfile: { + id: 'account-123', + }, + })) }) afterEach(() => { @@ -235,6 +256,59 @@ describe('Banner', () => { expect(screen.getByTestId('carousel')).toHaveClass('rounded-2xl') }) + + it('tracks enabled banner impressions with expected payload', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Enabled Banner 1'), + createMockBanner('2', 'disabled', 'Disabled Banner'), + createMockBanner('3', 'enabled', 'Enabled Banner 2'), + ], + isLoading: false, + isError: false, + }) + + render() + + expect(mockTrackEvent).toHaveBeenCalledTimes(2) + expect(mockTrackEvent).toHaveBeenNthCalledWith(1, 'explore_banner_impression', expect.objectContaining({ + banner_id: '1', + title: 'Enabled Banner 1', + sort: 1, + link: 'https://example.com', + page: 'explore', + language: 'en-US', + account_id: 'account-123', + event_time: expect.any(Number), + })) + expect(mockTrackEvent).toHaveBeenNthCalledWith(2, 'explore_banner_impression', expect.objectContaining({ + banner_id: '3', + title: 'Enabled Banner 2', + sort: 2, + link: 'https://example.com', + page: 'explore', + language: 'en-US', + account_id: 'account-123', + event_time: expect.any(Number), + })) + }) + + it('does not track impressions when account id is unavailable', () => { + mockUseSelector.mockImplementation(selector => selector({ + userProfile: { + id: '', + }, + })) + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled', 'Enabled Banner 1')], + isLoading: false, + isError: false, + }) + + render() + + expect(mockTrackEvent).not.toHaveBeenCalled() + }) }) describe('hover behavior', () => { @@ -435,8 +509,25 @@ describe('Banner', () => { const bannerItems = screen.getAllByTestId('banner-item') expect(bannerItems[0]).toHaveAttribute('data-banner-id', '1') + expect(bannerItems[0]).toHaveAttribute('data-sort', '1') expect(bannerItems[1]).toHaveAttribute('data-banner-id', '2') + expect(bannerItems[1]).toHaveAttribute('data-sort', '2') expect(bannerItems[2]).toHaveAttribute('data-banner-id', '3') + expect(bannerItems[2]).toHaveAttribute('data-sort', '3') + }) + + it('passes tracking context to banner item', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled', 'Banner 1')], + isLoading: false, + isError: false, + }) + + render() + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-language', 'en-US') + expect(bannerItem).toHaveAttribute('data-account-id', 'account-123') }) }) diff --git a/web/app/components/explore/banner/banner-item.tsx b/web/app/components/explore/banner/banner-item.tsx index d90a1060f9..c1e48bf420 100644 --- a/web/app/components/explore/banner/banner-item.tsx +++ b/web/app/components/explore/banner/banner-item.tsx @@ -4,6 +4,7 @@ import type { Banner } from '@/models/app' import { RiArrowRightLine } from '@remixicon/react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { trackEvent } from '@/app/components/base/amplitude' import { useCarousel } from '@/app/components/base/carousel' import { cn } from '@/utils/classnames' import { IndicatorButton } from './indicator-button' @@ -11,6 +12,9 @@ import { IndicatorButton } from './indicator-button' type BannerItemProps = { banner: Banner autoplayDelay: number + sort: number + language: string + accountId?: string isPaused?: boolean } @@ -20,7 +24,14 @@ const INDICATOR_WIDTH = 20 const INDICATOR_GAP = 8 const MIN_VIEW_MORE_WIDTH = 480 -export const BannerItem: FC = ({ banner, autoplayDelay, isPaused = false }) => { +export const BannerItem: FC = ({ + banner, + autoplayDelay, + sort, + language, + accountId, + isPaused = false, +}) => { const { t } = useTranslation() const { api, selectedIndex } = useCarousel() const { category, title, description, 'img-src': imgSrc } = banner.content @@ -91,9 +102,21 @@ export const BannerItem: FC = ({ banner, autoplayDelay, isPause const handleBannerClick = useCallback(() => { incrementResetKey() + + trackEvent('explore_banner_click', { + banner_id: banner.id, + title: banner.content.title, + sort, + link: banner.link, + page: 'explore', + language, + account_id: accountId, + event_time: Date.now(), + }) + if (banner.link) window.open(banner.link, '_blank', 'noopener,noreferrer') - }, [banner.link, incrementResetKey]) + }, [accountId, banner, incrementResetKey, language, sort]) const handleIndicatorClick = useCallback((index: number) => { incrementResetKey() diff --git a/web/app/components/explore/banner/banner.tsx b/web/app/components/explore/banner/banner.tsx index 4ec0efdb2b..a320bb1974 100644 --- a/web/app/components/explore/banner/banner.tsx +++ b/web/app/components/explore/banner/banner.tsx @@ -1,7 +1,9 @@ import type { FC } from 'react' import * as React from 'react' import { useEffect, useMemo, useRef, useState } from 'react' +import { trackEvent } from '@/app/components/base/amplitude' import { Carousel } from '@/app/components/base/carousel' +import { useSelector } from '@/context/app-context' import { useLocale } from '@/context/i18n' import { useGetBanners } from '@/service/use-explore' import Loading from '../../base/loading' @@ -23,9 +25,11 @@ const LoadingState: FC = () => ( const Banner: FC = () => { const locale = useLocale() const { data: banners, isLoading, isError } = useGetBanners(locale) + const accountId = useSelector(s => s.userProfile.id) const [isHovered, setIsHovered] = useState(false) const [isResizing, setIsResizing] = useState(false) const resizeTimerRef = useRef(null) + const trackedBannerIdsRef = useRef>(new Set()) const enabledBanners = useMemo( () => banners?.filter(banner => banner.status === 'enabled') ?? [], @@ -56,6 +60,28 @@ const Banner: FC = () => { } }, []) + useEffect(() => { + if (!accountId) + return + + enabledBanners.forEach((banner, index) => { + if (trackedBannerIdsRef.current.has(banner.id)) + return + + trackEvent('explore_banner_impression', { + banner_id: banner.id, + title: banner.content.title, + sort: index + 1, + link: banner.link, + page: 'explore', + language: locale, + account_id: accountId, + event_time: Date.now(), + }) + trackedBannerIdsRef.current.add(banner.id) + }) + }, [accountId, enabledBanners, locale]) + if (isLoading) return @@ -77,12 +103,15 @@ const Banner: FC = () => { onMouseLeave={() => setIsHovered(false)} > - {enabledBanners.map(banner => ( + {enabledBanners.map((banner, index) => ( ))} diff --git a/web/app/components/explore/create-app-modal/__tests__/index.spec.tsx b/web/app/components/explore/create-app-modal/__tests__/index.spec.tsx index 62353fb3c1..f389eeab29 100644 --- a/web/app/components/explore/create-app-modal/__tests__/index.spec.tsx +++ b/web/app/components/explore/create-app-modal/__tests__/index.spec.tsx @@ -19,7 +19,7 @@ vi.mock('@emoji-mart/data', () => ({ }, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useParams: () => ({}), })) diff --git a/web/app/components/explore/sidebar/__tests__/index.spec.tsx b/web/app/components/explore/sidebar/__tests__/index.spec.tsx index 36e6ab217c..0ce98f45db 100644 --- a/web/app/components/explore/sidebar/__tests__/index.spec.tsx +++ b/web/app/components/explore/sidebar/__tests__/index.spec.tsx @@ -1,19 +1,23 @@ import type { InstalledApp } from '@/models/explore' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import Toast from '@/app/components/base/toast' import { MediaType } from '@/hooks/use-breakpoints' import { AppModeEnum } from '@/types/app' import SideBar from '../index' +const { mockToastSuccess } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), +})) + const mockSegments = ['apps'] const mockPush = vi.fn() const mockUninstall = vi.fn() const mockUpdatePinStatus = vi.fn() let mockIsPending = false +let mockIsUninstallPending = false let mockInstalledApps: InstalledApp[] = [] let mockMediaType: string = MediaType.pc -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegments: () => mockSegments, useRouter: () => ({ push: mockPush, @@ -36,12 +40,24 @@ vi.mock('@/service/use-explore', () => ({ }), useUninstallApp: () => ({ mutateAsync: mockUninstall, + isPending: mockIsUninstallPending, }), useUpdateAppPinStatus: () => ({ mutateAsync: mockUpdatePinStatus, }), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + }, + } +}) + const createInstalledApp = (overrides: Partial = {}): InstalledApp => ({ id: overrides.id ?? 'app-123', uninstallable: overrides.uninstallable ?? false, @@ -67,9 +83,9 @@ describe('SideBar', () => { beforeEach(() => { vi.clearAllMocks() mockIsPending = false + mockIsUninstallPending = false mockInstalledApps = [] mockMediaType = MediaType.pc - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('Rendering', () => { @@ -79,11 +95,19 @@ describe('SideBar', () => { expect(screen.getByText('explore.sidebar.title')).toBeInTheDocument() }) + it('should expose an accessible name for the discovery link when the text is hidden', () => { + mockMediaType = MediaType.mobile + renderSideBar() + + expect(screen.getByRole('link', { name: 'explore.sidebar.title' })).toBeInTheDocument() + }) + it('should render workspace items when installed apps exist', () => { mockInstalledApps = [createInstalledApp()] renderSideBar() expect(screen.getByText('explore.sidebar.webApps')).toBeInTheDocument() + expect(screen.getByRole('region', { name: 'explore.sidebar.webApps' })).toBeInTheDocument() expect(screen.getByText('My App')).toBeInTheDocument() }) @@ -121,6 +145,15 @@ describe('SideBar', () => { const dividers = container.querySelectorAll('[class*="divider"], hr') expect(dividers.length).toBeGreaterThan(0) }) + + it('should render a button for toggling the sidebar and update its accessible name', () => { + renderSideBar() + + const toggleButton = screen.getByRole('button', { name: 'layout.sidebar.collapseSidebar' }) + fireEvent.click(toggleButton) + + expect(screen.getByRole('button', { name: 'layout.sidebar.expandSidebar' })).toBeInTheDocument() + }) }) describe('User Interactions', () => { @@ -135,10 +168,7 @@ describe('SideBar', () => { await waitFor(() => { expect(mockUninstall).toHaveBeenCalledWith('app-123') - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - message: 'common.api.remove', - })) + expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove') }) }) @@ -152,10 +182,7 @@ describe('SideBar', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-123', isPinned: true }) - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - message: 'common.api.success', - })) + expect(mockToastSuccess).toHaveBeenCalledWith('common.api.success') }) }) @@ -187,6 +214,18 @@ describe('SideBar', () => { expect(mockUninstall).not.toHaveBeenCalled() }) }) + + it('should disable dialog actions while uninstall is pending', async () => { + mockInstalledApps = [createInstalledApp()] + mockIsUninstallPending = true + renderSideBar() + + fireEvent.click(screen.getByTestId('item-operation-trigger')) + fireEvent.click(await screen.findByText('explore.sidebar.action.delete')) + + expect(screen.getByText('common.operation.cancel')).toBeDisabled() + expect(screen.getByText('common.operation.confirm')).toBeDisabled() + }) }) describe('Edge Cases', () => { diff --git a/web/app/components/explore/sidebar/app-nav-item/__tests__/index.spec.tsx b/web/app/components/explore/sidebar/app-nav-item/__tests__/index.spec.tsx index 299c181c98..26af458c55 100644 --- a/web/app/components/explore/sidebar/app-nav-item/__tests__/index.spec.tsx +++ b/web/app/components/explore/sidebar/app-nav-item/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import AppNavItem from '../index' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockPush, }), diff --git a/web/app/components/explore/sidebar/app-nav-item/index.tsx b/web/app/components/explore/sidebar/app-nav-item/index.tsx index 08558578f6..3f3d7a727e 100644 --- a/web/app/components/explore/sidebar/app-nav-item/index.tsx +++ b/web/app/components/explore/sidebar/app-nav-item/index.tsx @@ -2,11 +2,11 @@ import type { AppIconType } from '@/types/app' import { useHover } from 'ahooks' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useRef } from 'react' import AppIcon from '@/app/components/base/app-icon' import ItemOperation from '@/app/components/explore/item-operation' +import { useRouter } from '@/next/navigation' import { cn } from '@/utils/classnames' export type IAppNavItemProps = { diff --git a/web/app/components/explore/sidebar/index.tsx b/web/app/components/explore/sidebar/index.tsx index bafc745b01..095b838e03 100644 --- a/web/app/components/explore/sidebar/index.tsx +++ b/web/app/components/explore/sidebar/index.tsx @@ -1,19 +1,34 @@ 'use client' import { useBoolean } from 'ahooks' -import Link from 'next/link' -import { useSelectedLayoutSegments } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import Confirm from '@/app/components/base/confirm' import Divider from '@/app/components/base/divider' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, +} from '@/app/components/base/ui/alert-dialog' +import { ScrollArea } from '@/app/components/base/ui/scroll-area' +import { toast } from '@/app/components/base/ui/toast' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import Link from '@/next/link' +import { useSelectedLayoutSegments } from '@/next/navigation' import { useGetInstalledApps, useUninstallApp, useUpdateAppPinStatus } from '@/service/use-explore' import { cn } from '@/utils/classnames' -import Toast from '../../base/toast' import Item from './app-nav-item' import NoApps from './no-apps' +const expandedSidebarScrollAreaClassNames = { + content: 'space-y-0.5', + scrollbar: 'data-[orientation=vertical]:my-2 data-[orientation=vertical]:[margin-inline-end:-0.75rem]', + viewport: 'overscroll-contain', +} as const + const SideBar = () => { const { t } = useTranslation() const segments = useSelectedLayoutSegments() @@ -21,7 +36,7 @@ const SideBar = () => { const isDiscoverySelected = lastSegment === 'apps' const { data, isPending } = useGetInstalledApps() const installedApps = data?.installed_apps ?? [] - const { mutateAsync: uninstallApp } = useUninstallApp() + const { mutateAsync: uninstallApp, isPending: isUninstalling } = useUninstallApp() const { mutateAsync: updatePinStatus } = useUpdateAppPinStatus() const media = useBreakpoints() @@ -36,30 +51,50 @@ const SideBar = () => { const id = currId await uninstallApp(id) setShowConfirm(false) - Toast.notify({ - type: 'success', - message: t('api.remove', { ns: 'common' }), - }) + toast.success(t('api.remove', { ns: 'common' })) } const handleUpdatePinStatus = async (id: string, isPinned: boolean) => { await updatePinStatus({ appId: id, isPinned }) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) + toast.success(t('api.success', { ns: 'common' })) } const pinnedAppsCount = installedApps.filter(({ is_pinned }) => is_pinned).length + const shouldUseExpandedScrollArea = !isMobile && !isFold + const webAppsLabelId = React.useId() + const installedAppItems = installedApps.map(({ id, is_pinned, uninstallable, app: { name, icon_type, icon, icon_url, icon_background } }, index) => ( + + handleUpdatePinStatus(id, !is_pinned)} + uninstallable={uninstallable} + onDelete={(id) => { + setCurrId(id) + setShowConfirm(true) + }} + /> + {index === pinnedAppsCount - 1 && index !== installedApps.length - 1 && } + + )) + return ( -
    +
    - +
    {!isMobile && !isFold &&
    {t('sidebar.title', { ns: 'explore' })}
    } @@ -73,59 +108,67 @@ const SideBar = () => { )} {installedApps.length > 0 && ( -
    - {!isMobile && !isFold &&

    {t('sidebar.webApps', { ns: 'explore' })}

    } -
    - {installedApps.map(({ id, is_pinned, uninstallable, app: { name, icon_type, icon, icon_url, icon_background } }, index) => ( - - handleUpdatePinStatus(id, !is_pinned)} - uninstallable={uninstallable} - onDelete={(id) => { - setCurrId(id) - setShowConfirm(true) - }} - /> - {index === pinnedAppsCount - 1 && index !== installedApps.length - 1 && } - - ))} -
    -
    - )} - - {!isMobile && ( -
    - {isFold - ? +
    + {!isMobile && !isFold &&

    {t('sidebar.webApps', { ns: 'explore' })}

    } + {shouldUseExpandedScrollArea + ? ( +
    + + {installedAppItems} + +
    + ) : ( - +
    + {installedAppItems} +
    )}
    )} - {showConfirm && ( - setShowConfirm(false)} - /> + {!isMobile && ( +
    + +
    )} + + + +
    + + {t('sidebar.delete.title', { ns: 'explore' })} + + + {t('sidebar.delete.content', { ns: 'explore' })} + +
    + + + {t('operation.cancel', { ns: 'common' })} + + + {t('operation.confirm', { ns: 'common' })} + + +
    +
    ) } diff --git a/web/app/components/goto-anything/__tests__/command-selector.spec.tsx b/web/app/components/goto-anything/__tests__/command-selector.spec.tsx index 56e40a71f0..98c6ac784f 100644 --- a/web/app/components/goto-anything/__tests__/command-selector.spec.tsx +++ b/web/app/components/goto-anything/__tests__/command-selector.spec.tsx @@ -5,7 +5,7 @@ import { Command } from 'cmdk' import * as React from 'react' import CommandSelector from '../command-selector' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => '/app', })) diff --git a/web/app/components/goto-anything/__tests__/context.spec.tsx b/web/app/components/goto-anything/__tests__/context.spec.tsx index c427f76c61..70a30786df 100644 --- a/web/app/components/goto-anything/__tests__/context.spec.tsx +++ b/web/app/components/goto-anything/__tests__/context.spec.tsx @@ -3,7 +3,7 @@ import * as React from 'react' import { GotoAnythingProvider, useGotoAnythingContext } from '../context' let pathnameMock: string | null | undefined = '/' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => pathnameMock, })) diff --git a/web/app/components/goto-anything/__tests__/index.spec.tsx b/web/app/components/goto-anything/__tests__/index.spec.tsx index eb5fa8ccdd..b2050ef9fb 100644 --- a/web/app/components/goto-anything/__tests__/index.spec.tsx +++ b/web/app/components/goto-anything/__tests__/index.spec.tsx @@ -11,7 +11,7 @@ type TestSearchResult = Omit & { } const routerPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: routerPush, }), diff --git a/web/app/components/goto-anything/command-selector.tsx b/web/app/components/goto-anything/command-selector.tsx index bdb641cae6..59373c9e3a 100644 --- a/web/app/components/goto-anything/command-selector.tsx +++ b/web/app/components/goto-anything/command-selector.tsx @@ -1,9 +1,9 @@ import type { FC } from 'react' import type { ActionItem } from './actions/types' import { Command } from 'cmdk' -import { usePathname } from 'next/navigation' import { useEffect, useMemo } from 'react' import { useTranslation } from 'react-i18next' +import { usePathname } from '@/next/navigation' import { slashCommandRegistry } from './actions/commands/registry' type Props = { diff --git a/web/app/components/goto-anything/context.tsx b/web/app/components/goto-anything/context.tsx index 5c2bf3cb6b..28fb08ac17 100644 --- a/web/app/components/goto-anything/context.tsx +++ b/web/app/components/goto-anything/context.tsx @@ -1,9 +1,9 @@ 'use client' import type { ReactNode } from 'react' -import { usePathname } from 'next/navigation' import * as React from 'react' import { createContext, useContext, useEffect, useState } from 'react' +import { usePathname } from '@/next/navigation' import { isInWorkflowPage } from '../workflow/constants' /** diff --git a/web/app/components/goto-anything/hooks/__tests__/use-goto-anything-navigation.spec.ts b/web/app/components/goto-anything/hooks/__tests__/use-goto-anything-navigation.spec.ts index 1ac3bbc17c..c8a6a4a13c 100644 --- a/web/app/components/goto-anything/hooks/__tests__/use-goto-anything-navigation.spec.ts +++ b/web/app/components/goto-anything/hooks/__tests__/use-goto-anything-navigation.spec.ts @@ -16,7 +16,7 @@ type MockCommandResult = { let mockFindCommandResult: MockCommandResult = null -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), diff --git a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts index 73be6cd3ee..9c9871fa1d 100644 --- a/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts +++ b/web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts @@ -3,9 +3,9 @@ import type { RefObject } from 'react' import type { Plugin } from '../../plugins/types' import type { ActionItem, SearchResult } from '../actions/types' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation' +import { useRouter } from '@/next/navigation' import { slashCommandRegistry } from '../actions/commands/registry' export type UseGotoAnythingNavigationReturn = { diff --git a/web/app/components/header/__tests__/header-wrapper.spec.tsx b/web/app/components/header/__tests__/header-wrapper.spec.tsx index b1948e0992..cdb6a7a849 100644 --- a/web/app/components/header/__tests__/header-wrapper.spec.tsx +++ b/web/app/components/header/__tests__/header-wrapper.spec.tsx @@ -1,10 +1,10 @@ import { act, render, screen } from '@testing-library/react' -import { usePathname } from 'next/navigation' import { vi } from 'vitest' import { useEventEmitterContextContext } from '@/context/event-emitter' +import { usePathname } from '@/next/navigation' import HeaderWrapper from '../header-wrapper' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(), })) diff --git a/web/app/components/header/__tests__/index.spec.tsx b/web/app/components/header/__tests__/index.spec.tsx index 93ab7fb535..16e0854339 100644 --- a/web/app/components/header/__tests__/index.spec.tsx +++ b/web/app/components/header/__tests__/index.spec.tsx @@ -52,7 +52,7 @@ vi.mock('@/context/workspace-context-provider', () => ({ WorkspaceProvider: ({ children }: { children?: React.ReactNode }) => children, })) -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children?: React.ReactNode, href?: string }) => {children}, })) diff --git a/web/app/components/header/account-about/index.tsx b/web/app/components/header/account-about/index.tsx index b80cbb8f03..09ab89fc88 100644 --- a/web/app/components/header/account-about/index.tsx +++ b/web/app/components/header/account-about/index.tsx @@ -2,15 +2,15 @@ import type { LangGeniusVersionResponse } from '@/models/common' import { RiCloseLine } from '@remixicon/react' import dayjs from 'dayjs' -import Link from 'next/link' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import DifyLogo from '@/app/components/base/logo/dify-logo' import Modal from '@/app/components/base/modal' import { IS_CE_EDITION } from '@/config' - import { useGlobalPublicStore } from '@/context/global-public-context' +import Link from '@/next/link' + type IAccountSettingProps = { langGeniusVersionInfo: LangGeniusVersionResponse onCancel: () => void diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index e1d4c45810..9d4226c33a 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -3,12 +3,12 @@ import type { ModalContextState } from '@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { useRouter } from 'next/navigation' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { useLogout } from '@/service/use-common' import AppSelector from '../index' @@ -53,8 +53,8 @@ vi.mock('@/service/use-common', () => ({ useLogout: vi.fn(), })) -vi.mock('next/navigation', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/next/navigation', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, useRouter: vi.fn(), @@ -69,6 +69,7 @@ vi.mock('@/context/i18n', () => ({ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, + AMPLITUDE_API_KEY: '', ZENDESK_WIDGET_KEY: '', SUPPORT_EMAIL_ADDRESS: '', }, @@ -80,6 +81,8 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ })) vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, + get AMPLITUDE_API_KEY() { return mockConfig.AMPLITUDE_API_KEY }, + get isAmplitudeEnabled() { return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY }, get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, get SUPPORT_EMAIL_ADDRESS() { return mockConfig.SUPPORT_EMAIL_ADDRESS }, IS_DEV: false, diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 0a5779839e..1697433ac4 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -1,8 +1,6 @@ 'use client' import type { MouseEventHandler, ReactNode } from 'react' -import Link from 'next/link' -import { useRouter } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' @@ -18,6 +16,8 @@ import { useDocLink } from '@/context/i18n' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { env } from '@/env' +import Link from '@/next/link' +import { useRouter } from '@/next/navigation' import { useLogout } from '@/service/use-common' import { cn } from '@/utils/classnames' import AccountAbout from '../account-about' diff --git a/web/app/components/header/account-setting/Integrations-page/index.tsx b/web/app/components/header/account-setting/Integrations-page/index.tsx index ef234b5db7..29d0d9fcd3 100644 --- a/web/app/components/header/account-setting/Integrations-page/index.tsx +++ b/web/app/components/header/account-setting/Integrations-page/index.tsx @@ -1,7 +1,7 @@ 'use client' -import Link from 'next/link' import { useTranslation } from 'react-i18next' +import Link from '@/next/link' import { useAccountIntegrates } from '@/service/use-common' import { cn } from '@/utils/classnames' import s from './index.module.css' diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index 38cbb58a1b..279af0b114 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -27,7 +27,7 @@ vi.mock('@/context/app-context', async (importOriginal) => { } }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: vi.fn(), replace: vi.fn(), @@ -315,14 +315,14 @@ describe('AccountSetting', () => { it('should handle scroll event in panel', () => { // Act renderAccountSetting() - const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto') + const scrollContainer = screen.getByRole('dialog').querySelector('.overscroll-contain') // Assert expect(scrollContainer).toBeInTheDocument() if (scrollContainer) { // Scroll down fireEvent.scroll(scrollContainer, { target: { scrollTop: 100 } }) - expect(scrollContainer).toHaveClass('overflow-y-auto') + expect(scrollContainer).toHaveClass('overscroll-contain') // Scroll back up fireEvent.scroll(scrollContainer, { target: { scrollTop: 0 } }) diff --git a/web/app/components/header/account-setting/menu-dialog.dialog.spec.tsx b/web/app/components/header/account-setting/__tests__/menu-dialog.dialog.spec.tsx similarity index 96% rename from web/app/components/header/account-setting/menu-dialog.dialog.spec.tsx rename to web/app/components/header/account-setting/__tests__/menu-dialog.dialog.spec.tsx index 627b764eb2..db8aec4ec1 100644 --- a/web/app/components/header/account-setting/menu-dialog.dialog.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/menu-dialog.dialog.spec.tsx @@ -1,6 +1,6 @@ import type { ReactNode } from 'react' import { render } from '@testing-library/react' -import MenuDialog from './menu-dialog' +import MenuDialog from '../menu-dialog' type DialogProps = { children: ReactNode diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx index efe6c46dcc..5f1492f14a 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx @@ -78,6 +78,7 @@ const ApiBasedExtensionModal: FC = ({
    diff --git a/web/app/components/header/account-setting/api-based-extension-page/selector.tsx b/web/app/components/header/account-setting/api-based-extension-page/selector.tsx index 38acb73154..62052aece6 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/selector.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/selector.tsx @@ -69,7 +69,7 @@ const ApiBasedExtensionSelector: FC = ({ ) } - +
    diff --git a/web/app/components/header/account-setting/data-source-page-new/__tests__/install-from-marketplace.spec.tsx b/web/app/components/header/account-setting/data-source-page-new/__tests__/install-from-marketplace.spec.tsx index daf9d3b988..a9d81a12e0 100644 --- a/web/app/components/header/account-setting/data-source-page-new/__tests__/install-from-marketplace.spec.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/__tests__/install-from-marketplace.spec.tsx @@ -16,7 +16,7 @@ vi.mock('next-themes', () => ({ useTheme: vi.fn(), })) -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children: React.ReactNode, href: string }) => ( {children} ), diff --git a/web/app/components/header/account-setting/data-source-page-new/configure.tsx b/web/app/components/header/account-setting/data-source-page-new/configure.tsx index a3dba783e1..484338d333 100644 --- a/web/app/components/header/account-setting/data-source-page-new/configure.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/configure.tsx @@ -84,7 +84,7 @@ const Configure = ({ {t('dataSource.configure', { ns: 'common' })} - +
    { !!canOAuth && ( @@ -104,7 +104,7 @@ const Configure = ({ } { !!canApiKey && !!canOAuth && ( -
    +
    OR
    diff --git a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx index f02e276f55..1a1ca19c3e 100644 --- a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx @@ -4,7 +4,6 @@ import { RiArrowRightUpLine, } from '@remixicon/react' import { useTheme } from 'next-themes' -import Link from 'next/link' import { memo, useCallback, @@ -15,6 +14,7 @@ import Divider from '@/app/components/base/divider' import Loading from '@/app/components/base/loading' import List from '@/app/components/plugins/marketplace/list' import ProviderCard from '@/app/components/plugins/provider-card' +import Link from '@/next/link' import { cn } from '@/utils/classnames' import { getMarketplaceUrl } from '@/utils/var' import { diff --git a/web/app/components/header/account-setting/data-source-page-new/operator.tsx b/web/app/components/header/account-setting/data-source-page-new/operator.tsx index 14bdee4fd0..c5b2a948de 100644 --- a/web/app/components/header/account-setting/data-source-page-new/operator.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/operator.tsx @@ -39,7 +39,7 @@ const Operator = ({ text: (
    -
    {t('auth.setDefault', { ns: 'plugin' })}
    +
    {t('auth.setDefault', { ns: 'plugin' })}
    ), }, @@ -51,7 +51,7 @@ const Operator = ({ text: (
    -
    {t('operation.rename', { ns: 'common' })}
    +
    {t('operation.rename', { ns: 'common' })}
    ), }, @@ -66,7 +66,7 @@ const Operator = ({ text: (
    -
    {t('operation.edit', { ns: 'common' })}
    +
    {t('operation.edit', { ns: 'common' })}
    ), }, @@ -81,7 +81,7 @@ const Operator = ({ text: (
    -
    {t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
    +
    {t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
    ), }, @@ -98,7 +98,7 @@ const Operator = ({ text: (
    -
    +
    {t('operation.remove', { ns: 'common' })}
    @@ -122,7 +122,7 @@ const Operator = ({ items={items} secondItems={secondItems} onSelect={handleSelect} - popupClassName="z-[61]" + popupClassName="z-[1002]" triggerProps={{ size: 'l', }} diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx deleted file mode 100644 index dad82d81b9..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx +++ /dev/null @@ -1,462 +0,0 @@ -import type { UseQueryResult } from '@tanstack/react-query' -import type { AppContextValue } from '@/context/app-context' -import type { DataSourceNotion as TDataSourceNotion } from '@/models/common' -import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' -import { useAppContext } from '@/context/app-context' -import { useDataSourceIntegrates, useInvalidDataSourceIntegrates, useNotionConnection } from '@/service/use-common' -import DataSourceNotion from '../index' - -/** - * DataSourceNotion Component Tests - * Using Unit approach with real Panel and sibling components to test Notion integration logic. - */ - -type MockQueryResult = UseQueryResult - -// Mock dependencies -vi.mock('@/context/app-context', () => ({ - useAppContext: vi.fn(), -})) - -vi.mock('@/service/common', () => ({ - syncDataSourceNotion: vi.fn(), - updateDataSourceNotionAction: vi.fn(), -})) - -vi.mock('@/service/use-common', () => ({ - useDataSourceIntegrates: vi.fn(), - useNotionConnection: vi.fn(), - useInvalidDataSourceIntegrates: vi.fn(), -})) - -describe('DataSourceNotion Component', () => { - const mockWorkspaces: TDataSourceNotion[] = [ - { - id: 'ws-1', - provider: 'notion', - is_bound: true, - source_info: { - workspace_name: 'Workspace 1', - workspace_icon: 'https://example.com/icon-1.png', - workspace_id: 'notion-ws-1', - total: 10, - pages: [], - }, - }, - ] - - const baseAppContext: AppContextValue = { - userProfile: { id: 'test-user-id', name: 'test-user', email: 'test@example.com', avatar: '', avatar_url: '', is_password_set: true }, - mutateUserProfile: vi.fn(), - currentWorkspace: { id: 'ws-id', name: 'Workspace', plan: 'basic', status: 'normal', created_at: 0, role: 'owner', providers: [], trial_credits: 0, trial_credits_used: 0, next_credit_reset_date: 0 }, - isCurrentWorkspaceManager: true, - isCurrentWorkspaceOwner: true, - isCurrentWorkspaceEditor: true, - isCurrentWorkspaceDatasetOperator: false, - mutateCurrentWorkspace: vi.fn(), - langGeniusVersionInfo: { current_version: '0.1.0', latest_version: '0.1.1', version: '0.1.1', release_date: '', release_notes: '', can_auto_update: false, current_env: 'test' }, - useSelector: vi.fn(), - isLoadingCurrentWorkspace: false, - isValidatingCurrentWorkspace: false, - } - - /* eslint-disable-next-line ts/no-explicit-any */ - const mockQuerySuccess = (data: T): MockQueryResult => ({ data, isSuccess: true, isError: false, isLoading: false, isPending: false, status: 'success', error: null, fetchStatus: 'idle' } as any) - /* eslint-disable-next-line ts/no-explicit-any */ - const mockQueryPending = (): MockQueryResult => ({ data: undefined, isSuccess: false, isError: false, isLoading: true, isPending: true, status: 'pending', error: null, fetchStatus: 'fetching' } as any) - - const originalLocation = window.location - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useAppContext).mockReturnValue(baseAppContext) - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [] })) - vi.mocked(useNotionConnection).mockReturnValue(mockQueryPending()) - vi.mocked(useInvalidDataSourceIntegrates).mockReturnValue(vi.fn()) - - const locationMock = { href: '', assign: vi.fn() } - Object.defineProperty(window, 'location', { value: locationMock, writable: true, configurable: true }) - - // Clear document body to avoid toast leaks between tests - document.body.innerHTML = '' - }) - - afterEach(() => { - Object.defineProperty(window, 'location', { value: originalLocation, writable: true, configurable: true }) - }) - - const getWorkspaceItem = (name: string) => { - const nameEl = screen.getByText(name) - return (nameEl.closest('div[class*="workspace-item"]') || nameEl.parentElement) as HTMLElement - } - - describe('Rendering', () => { - it('should render with no workspaces initially and call integration hook', () => { - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.title')).toBeInTheDocument() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - }) - - it('should render with provided workspaces and pass initialData to hook', () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - expect(screen.getByText('Workspace 1')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.connected')).toBeInTheDocument() - expect(screen.getByAltText('workspace icon')).toHaveAttribute('src', 'https://example.com/icon-1.png') - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: mockWorkspaces } }) - }) - - it('should handle workspaces prop being an empty array', () => { - // Act - render() - - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: [] } }) - }) - - it('should handle optional workspaces configurations', () => { - // Branch: workspaces passed as undefined - const { rerender } = render() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - - // Branch: workspaces passed as null - /* eslint-disable-next-line ts/no-explicit-any */ - rerender() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - - // Branch: workspaces passed as [] - rerender() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: [] } }) - }) - - it('should handle cases where integrates data is loading or broken', () => { - // Act (Loading) - const { rerender } = render() - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQueryPending()) - rerender() - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - - // Act (Broken) - const brokenData = {} as { data: TDataSourceNotion[] } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess(brokenData)) - rerender() - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates being nullish', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: undefined, isSuccess: true } as any) - render() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates data being nullish', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: { data: null }, isSuccess: true } as any) - render() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates data being valid', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: { data: [{ id: '1', is_bound: true, source_info: { workspace_name: 'W', workspace_icon: 'https://example.com/i.png', total: 1, pages: [] } }] }, isSuccess: true } as any) - render() - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - }) - - it('should cover all possible falsy/nullish branches for integrates and workspaces', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - const { rerender } = render() - - const integratesCases = [ - undefined, - null, - {}, - { data: null }, - { data: undefined }, - { data: [] }, - { data: [mockWorkspaces[0]] }, - { data: false }, - { data: 0 }, - { data: '' }, - 123, - 'string', - false, - ] - - integratesCases.forEach((val) => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: val, isSuccess: true } as any) - /* eslint-disable-next-line ts/no-explicit-any */ - rerender() - }) - - expect(useDataSourceIntegrates).toHaveBeenCalled() - }) - }) - - describe('User Permissions', () => { - it('should pass readOnly as false when user is a manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: true }) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.title').closest('div')).not.toHaveClass('grayscale') - }) - - it('should pass readOnly as true when user is NOT a manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: false }) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.connect')).toHaveClass('opacity-50', 'grayscale') - }) - }) - - describe('Configure and Auth Actions', () => { - it('should handle configure action when user is workspace manager', () => { - // Arrange - render() - - // Act - fireEvent.click(screen.getByText('common.dataSource.connect')) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(true) - }) - - it('should block configure action when user is NOT workspace manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: false }) - render() - - // Act - fireEvent.click(screen.getByText('common.dataSource.connect')) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(false) - }) - - it('should redirect if auth URL is available when "Auth Again" is clicked', async () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http://auth-url' })) - render() - - // Act - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - // Assert - expect(window.location.href).toBe('http://auth-url') - }) - - it('should trigger connection flow if URL is missing when "Auth Again" is clicked', async () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - render() - - // Act - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(true) - }) - }) - - describe('Side Effects (Redirection and Toast)', () => { - it('should redirect automatically when connection data returns an http URL', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http://redirect-url' })) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('http://redirect-url') - }) - }) - - it('should show toast notification when connection data is "internal"', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'internal' })) - - // Act - render() - - // Assert - expect(await screen.findByText('common.dataSource.notion.integratedAlert')).toBeInTheDocument() - }) - - it('should handle various data types and missing properties in connection data correctly', async () => { - // Arrange & Act (Unknown string) - const { rerender } = render() - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'unknown' })) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - expect(screen.queryByText('common.dataSource.notion.integratedAlert')).not.toBeInTheDocument() - }) - - // Act (Broken object) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({} as any)) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - - // Act (Non-string) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 123 } as any)) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - - it('should redirect if data starts with "http" even if it is just "http"', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http' })) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('http') - }) - }) - - it('should skip side effect logic if connection data is an object but missing the "data" property', async () => { - // Arrange - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({} as any) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - - it('should skip side effect logic if data.data is falsy', async () => { - // Arrange - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({ data: { data: null } } as any) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - }) - - describe('Additional Action Edge Cases', () => { - it.each([ - undefined, - null, - {}, - { data: undefined }, - { data: null }, - { data: '' }, - { data: 0 }, - { data: false }, - { data: 'http' }, - { data: 'internal' }, - { data: 'unknown' }, - ])('should cover connection data branch: %s', async (val) => { - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({ data: val, isSuccess: true } as any) - - render() - - // Trigger handleAuthAgain with these values - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - expect(useNotionConnection).toHaveBeenCalled() - }) - }) - - describe('Edge Cases in Workspace Data', () => { - it('should render correctly with missing source_info optional fields', async () => { - // Arrange - const workspaceWithMissingInfo: TDataSourceNotion = { - id: 'ws-2', - provider: 'notion', - is_bound: false, - source_info: { workspace_name: 'Workspace 2', workspace_id: 'notion-ws-2', workspace_icon: null, pages: [] }, - } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [workspaceWithMissingInfo] })) - - // Act - render() - - // Assert - expect(screen.getByText('Workspace 2')).toBeInTheDocument() - - const workspaceItem = getWorkspaceItem('Workspace 2') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - - expect(await screen.findByText('0 common.dataSource.notion.pagesAuthorized')).toBeInTheDocument() - }) - - it('should display inactive status correctly for unbound workspaces', () => { - // Arrange - const inactiveWS: TDataSourceNotion = { - id: 'ws-3', - provider: 'notion', - is_bound: false, - source_info: { workspace_name: 'Workspace 3', workspace_icon: 'https://example.com/icon-3.png', workspace_id: 'notion-ws-3', total: 5, pages: [] }, - } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [inactiveWS] })) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.disconnected')).toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx deleted file mode 100644 index 0959383f29..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx +++ /dev/null @@ -1,103 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { DataSourceNotion as TDataSourceNotion } from '@/models/common' -import { noop } from 'es-toolkit/function' -import * as React from 'react' -import { useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import NotionIcon from '@/app/components/base/notion-icon' -import Toast from '@/app/components/base/toast' -import { useAppContext } from '@/context/app-context' -import { useDataSourceIntegrates, useNotionConnection } from '@/service/use-common' -import Panel from '../panel' -import { DataSourceType } from '../panel/types' - -const Icon: FC<{ - src: string - name: string - className: string -}> = ({ src, name, className }) => { - return ( - - ) -} -type Props = { - workspaces?: TDataSourceNotion[] -} - -const DataSourceNotion: FC = ({ - workspaces, -}) => { - const { isCurrentWorkspaceManager } = useAppContext() - const [canConnectNotion, setCanConnectNotion] = useState(false) - const { data: integrates } = useDataSourceIntegrates({ - initialData: workspaces ? { data: workspaces } : undefined, - }) - const { data } = useNotionConnection(canConnectNotion) - const { t } = useTranslation() - - const resolvedWorkspaces = integrates?.data ?? [] - const connected = !!resolvedWorkspaces.length - - const handleConnectNotion = () => { - if (!isCurrentWorkspaceManager) - return - - setCanConnectNotion(true) - } - - const handleAuthAgain = () => { - if (data?.data) - window.location.href = data.data - else - setCanConnectNotion(true) - } - - useEffect(() => { - if (data && 'data' in data) { - if (data.data && typeof data.data === 'string' && data.data.startsWith('http')) { - window.location.href = data.data - } - else if (data.data === 'internal') { - Toast.notify({ - type: 'info', - message: t('dataSource.notion.integratedAlert', { ns: 'common' }), - }) - } - } - }, [data, t]) - - return ( - ({ - id: workspace.id, - logo: ({ className }: { className: string }) => ( - - ), - name: workspace.source_info.workspace_name, - isActive: workspace.is_bound, - notionConfig: { - total: workspace.source_info.total || 0, - }, - }))} - onRemove={noop} // handled in operation/index.tsx - notionActions={{ - onChangeAuthorizedPage: handleAuthAgain, - }} - /> - ) -} -export default React.memo(DataSourceNotion) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx deleted file mode 100644 index f433b10020..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx +++ /dev/null @@ -1,137 +0,0 @@ -import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' -import { syncDataSourceNotion, updateDataSourceNotionAction } from '@/service/common' -import { useInvalidDataSourceIntegrates } from '@/service/use-common' -import Operate from '../index' - -/** - * Operate Component (Notion) Tests - * This component provides actions like Sync, Change Pages, and Remove for Notion data sources. - */ - -// Mock services and toast -vi.mock('@/service/common', () => ({ - syncDataSourceNotion: vi.fn(), - updateDataSourceNotionAction: vi.fn(), -})) - -vi.mock('@/service/use-common', () => ({ - useInvalidDataSourceIntegrates: vi.fn(), -})) - -describe('Operate Component (Notion)', () => { - const mockPayload = { - id: 'test-notion-id', - total: 5, - } - const mockOnAuthAgain = vi.fn() - const mockInvalidate = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useInvalidDataSourceIntegrates).mockReturnValue(mockInvalidate) - vi.mocked(syncDataSourceNotion).mockResolvedValue({ result: 'success' }) - vi.mocked(updateDataSourceNotionAction).mockResolvedValue({ result: 'success' }) - }) - - describe('Rendering', () => { - it('should render the menu button initially', () => { - // Act - const { container } = render() - - // Assert - const menuButton = within(container).getByRole('button') - expect(menuButton).toBeInTheDocument() - expect(menuButton).not.toHaveClass('bg-state-base-hover') - }) - - it('should open the menu and show all options when clicked', async () => { - // Arrange - const { container } = render() - const menuButton = within(container).getByRole('button') - - // Act - fireEvent.click(menuButton) - - // Assert - expect(await screen.findByText('common.dataSource.notion.changeAuthorizedPages')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.sync')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.remove')).toBeInTheDocument() - expect(screen.getByText(/5/)).toBeInTheDocument() - expect(screen.getByText(/common.dataSource.notion.pagesAuthorized/)).toBeInTheDocument() - expect(menuButton).toHaveClass('bg-state-base-hover') - }) - }) - - describe('Menu Actions', () => { - it('should call onAuthAgain when Change Authorized Pages is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const option = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - - // Act - fireEvent.click(option) - - // Assert - expect(mockOnAuthAgain).toHaveBeenCalledTimes(1) - }) - - it('should call handleSync, show success toast, and invalidate cache when Sync is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const syncBtn = await screen.findByText('common.dataSource.notion.sync') - - // Act - fireEvent.click(syncBtn) - - // Assert - await waitFor(() => { - expect(syncDataSourceNotion).toHaveBeenCalledWith({ - url: `/oauth/data-source/notion/${mockPayload.id}/sync`, - }) - }) - expect((await screen.findAllByText('common.api.success')).length).toBeGreaterThan(0) - expect(mockInvalidate).toHaveBeenCalledTimes(1) - }) - - it('should call handleRemove, show success toast, and invalidate cache when Remove is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const removeBtn = await screen.findByText('common.dataSource.notion.remove') - - // Act - fireEvent.click(removeBtn) - - // Assert - await waitFor(() => { - expect(updateDataSourceNotionAction).toHaveBeenCalledWith({ - url: `/data-source/integrates/${mockPayload.id}/disable`, - }) - }) - expect((await screen.findAllByText('common.api.success')).length).toBeGreaterThan(0) - expect(mockInvalidate).toHaveBeenCalledTimes(1) - }) - }) - - describe('State Transitions', () => { - it('should toggle the open class on the button based on menu visibility', async () => { - // Arrange - const { container } = render() - const menuButton = within(container).getByRole('button') - - // Act (Open) - fireEvent.click(menuButton) - // Assert - expect(menuButton).toHaveClass('bg-state-base-hover') - - // Act (Close - click again) - fireEvent.click(menuButton) - // Assert - await waitFor(() => { - expect(menuButton).not.toHaveClass('bg-state-base-hover') - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx deleted file mode 100644 index 043eb3c846..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx +++ /dev/null @@ -1,103 +0,0 @@ -'use client' -import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { - RiDeleteBinLine, - RiLoopLeftLine, - RiMoreFill, - RiStickyNoteAddLine, -} from '@remixicon/react' -import { Fragment } from 'react' -import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' -import { syncDataSourceNotion, updateDataSourceNotionAction } from '@/service/common' -import { useInvalidDataSourceIntegrates } from '@/service/use-common' -import { cn } from '@/utils/classnames' - -type OperateProps = { - payload: { - id: string - total: number - } - onAuthAgain: () => void -} -export default function Operate({ - payload, - onAuthAgain, -}: OperateProps) { - const { t } = useTranslation() - const invalidateDataSourceIntegrates = useInvalidDataSourceIntegrates() - - const updateIntegrates = () => { - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - invalidateDataSourceIntegrates() - } - const handleSync = async () => { - await syncDataSourceNotion({ url: `/oauth/data-source/notion/${payload.id}/sync` }) - updateIntegrates() - } - const handleRemove = async () => { - await updateDataSourceNotionAction({ url: `/data-source/integrates/${payload.id}/disable` }) - updateIntegrates() - } - - return ( - - { - ({ open }) => ( - <> - - - - - -
    - -
    - -
    -
    {t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
    -
    - {payload.total} - {' '} - {t('dataSource.notion.pagesAuthorized', { ns: 'common' })} -
    -
    -
    -
    - -
    - -
    {t('dataSource.notion.sync', { ns: 'common' })}
    -
    -
    -
    - -
    -
    - -
    {t('dataSource.notion.remove', { ns: 'common' })}
    -
    -
    -
    -
    -
    - - ) - } -
    - ) -} diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx deleted file mode 100644 index dadda4a349..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx +++ /dev/null @@ -1,204 +0,0 @@ -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigFirecrawlModal from '../config-firecrawl-modal' - -/** - * ConfigFirecrawlModal Component Tests - * Tests validation, save logic, and basic rendering for the Firecrawl configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigFirecrawlModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with all fields and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByPlaceholderText('https://api.firecrawl.dev')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.firecrawl\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://www.firecrawl.dev/account') - }) - }) - - describe('Form Interactions', () => { - it('should update state when input fields change', async () => { - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder') - const baseUrlInput = screen.getByPlaceholderText('https://api.firecrawl.dev') - - // Act - fireEvent.change(apiKeyInput, { target: { value: 'firecrawl-key' } }) - fireEvent.change(baseUrlInput, { target: { value: 'https://custom.firecrawl.dev' } }) - - // Assert - expect(apiKeyInput).toHaveValue('firecrawl-key') - expect(baseUrlInput).toHaveValue('https://custom.firecrawl.dev') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - - it('should show error for invalid Base URL format', async () => { - const user = userEvent.setup() - // Arrange - render() - const baseUrlInput = screen.getByPlaceholderText('https://api.firecrawl.dev') - - // Act - await user.type(baseUrlInput, 'ftp://invalid-url.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.urlError')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key and custom URL', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'valid-key') - await user.type(screen.getByPlaceholderText('https://api.firecrawl.dev'), 'http://my-firecrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: 'firecrawl', - credentials: { - auth_type: 'bearer', - config: { - api_key: 'valid-key', - base_url: 'http://my-firecrawl.com', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should use default Base URL if none is provided during save', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://api.firecrawl.dev', - }), - }), - })) - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: CommonResponse) => void - const savePromise = new Promise((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should accept base_url starting with https://', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - await user.type(screen.getByPlaceholderText('https://api.firecrawl.dev'), 'https://secure-firecrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://secure-firecrawl.com', - }), - }), - })) - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx deleted file mode 100644 index 26c53993c1..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx +++ /dev/null @@ -1,179 +0,0 @@ -import { render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { DataSourceProvider } from '@/models/common' -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigJinaReaderModal from '../config-jina-reader-modal' - -/** - * ConfigJinaReaderModal Component Tests - * Tests validation, save logic, and basic rendering for the Jina Reader configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigJinaReaderModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with API Key field and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.jinaReader.configJinaReader')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.jinaReader\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://jina.ai/reader/') - }) - }) - - describe('Form Interactions', () => { - it('should update state when API Key field changes', async () => { - const user = userEvent.setup() - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - - // Act - await user.type(apiKeyInput, 'jina-test-key') - - // Assert - expect(apiKeyInput).toHaveValue('jina-test-key') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - - // Act - await user.type(apiKeyInput, 'valid-jina-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: DataSourceProvider.jinaReader, - credentials: { - auth_type: 'bearer', - config: { - api_key: 'valid-jina-key', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: { result: 'success' }) => void - const savePromise = new Promise<{ result: 'success' }>((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder'), 'test-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should show encryption info and external link in the modal', async () => { - render() - - // Verify PKCS1_OAEP link exists - const pkcsLink = screen.getByText('PKCS1_OAEP') - expect(pkcsLink.closest('a')).toHaveAttribute('href', 'https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html') - - // Verify the Jina Reader external link - const jinaLink = screen.getByRole('link', { name: /datasetCreation\.jinaReader\.getApiKeyLinkText/i }) - expect(jinaLink).toHaveAttribute('target', '_blank') - }) - - it('should return early when save is clicked while already saving (isSaving guard)', async () => { - const user = userEvent.setup() - // Arrange - a save that never resolves so isSaving stays true - let resolveFirst: (value: { result: 'success' }) => void - const neverResolves = new Promise<{ result: 'success' }>((resolve) => { - resolveFirst = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(neverResolves) - render() - - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - await user.type(apiKeyInput, 'valid-key') - - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - // First click - starts saving, isSaving becomes true - await user.click(saveBtn) - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Second click using fireEvent bypasses disabled check - hits isSaving guard - const { fireEvent: fe } = await import('@testing-library/react') - fe.click(saveBtn) - // Still only called once because isSaving=true returns early - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveFirst!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalled()) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx deleted file mode 100644 index 6c5961be54..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx +++ /dev/null @@ -1,204 +0,0 @@ -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigWatercrawlModal from '../config-watercrawl-modal' - -/** - * ConfigWatercrawlModal Component Tests - * Tests validation, save logic, and basic rendering for the Watercrawl configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigWatercrawlModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with all fields and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.watercrawl.configWatercrawl')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByPlaceholderText('https://app.watercrawl.dev')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.watercrawl\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://app.watercrawl.dev/') - }) - }) - - describe('Form Interactions', () => { - it('should update state when input fields change', async () => { - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder') - const baseUrlInput = screen.getByPlaceholderText('https://app.watercrawl.dev') - - // Act - fireEvent.change(apiKeyInput, { target: { value: 'water-key' } }) - fireEvent.change(baseUrlInput, { target: { value: 'https://custom.watercrawl.dev' } }) - - // Assert - expect(apiKeyInput).toHaveValue('water-key') - expect(baseUrlInput).toHaveValue('https://custom.watercrawl.dev') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - - it('should show error for invalid Base URL format', async () => { - const user = userEvent.setup() - // Arrange - render() - const baseUrlInput = screen.getByPlaceholderText('https://app.watercrawl.dev') - - // Act - await user.type(baseUrlInput, 'ftp://invalid-url.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.urlError')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key and custom URL', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'valid-key') - await user.type(screen.getByPlaceholderText('https://app.watercrawl.dev'), 'http://my-watercrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: 'watercrawl', - credentials: { - auth_type: 'x-api-key', - config: { - api_key: 'valid-key', - base_url: 'http://my-watercrawl.com', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should use default Base URL if none is provided during save', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://app.watercrawl.dev', - }), - }), - })) - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: CommonResponse) => void - const savePromise = new Promise((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should accept base_url starting with https://', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - await user.type(screen.getByPlaceholderText('https://app.watercrawl.dev'), 'https://secure-watercrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://secure-watercrawl.com', - }), - }), - })) - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx deleted file mode 100644 index 1e95cbd087..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx +++ /dev/null @@ -1,251 +0,0 @@ -import type { AppContextValue } from '@/context/app-context' -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' - -import { useAppContext } from '@/context/app-context' -import { DataSourceProvider } from '@/models/common' -import { fetchDataSources, removeDataSourceApiKeyBinding } from '@/service/datasets' -import DataSourceWebsite from '../index' - -/** - * DataSourceWebsite Component Tests - * Tests integration of multiple website scraping providers (Firecrawl, WaterCrawl, Jina Reader). - */ - -type DataSourcesResponse = CommonResponse & { - sources: Array<{ id: string, provider: DataSourceProvider }> -} - -// Mock App Context -vi.mock('@/context/app-context', () => ({ - useAppContext: vi.fn(), -})) - -// Mock Service calls -vi.mock('@/service/datasets', () => ({ - fetchDataSources: vi.fn(), - removeDataSourceApiKeyBinding: vi.fn(), - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('DataSourceWebsite Component', () => { - const mockSources = [ - { id: '1', provider: DataSourceProvider.fireCrawl }, - { id: '2', provider: DataSourceProvider.waterCrawl }, - { id: '3', provider: DataSourceProvider.jinaReader }, - ] - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useAppContext).mockReturnValue({ isCurrentWorkspaceManager: true } as unknown as AppContextValue) - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [] } as DataSourcesResponse) - }) - - // Helper to render and wait for initial fetch to complete - const renderAndWait = async (provider: DataSourceProvider) => { - const result = render() - await waitFor(() => expect(fetchDataSources).toHaveBeenCalled()) - return result - } - - describe('Data Initialization', () => { - it('should fetch data sources on mount and reflect configured status', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: mockSources } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument() - }) - - it('should pass readOnly status based on workspace manager permissions', async () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ isCurrentWorkspaceManager: false } as unknown as AppContextValue) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(screen.getByText('common.dataSource.configure')).toHaveClass('cursor-default') - }) - }) - - describe('Provider Specific Rendering', () => { - it('should render correct logo and name for Firecrawl', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[0]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(await screen.findByText('Firecrawl')).toBeInTheDocument() - expect(screen.getByText('🔥')).toBeInTheDocument() - }) - - it('should render correct logo and name for WaterCrawl', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[1]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.waterCrawl) - - // Assert - const elements = await screen.findAllByText('WaterCrawl') - expect(elements.length).toBeGreaterThanOrEqual(1) - }) - - it('should render correct logo and name for Jina Reader', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[2]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.jinaReader) - - // Assert - const elements = await screen.findAllByText('Jina Reader') - expect(elements.length).toBeGreaterThanOrEqual(1) - }) - }) - - describe('Modal Interactions', () => { - it('should manage opening and closing of configuration modals', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - - // Act (Open) - fireEvent.click(screen.getByText('common.dataSource.configure')) - // Assert - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - - // Act (Cancel) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - // Assert - expect(screen.queryByText('datasetCreation.firecrawl.configFirecrawl')).not.toBeInTheDocument() - }) - - it('should re-fetch sources after saving configuration (Watercrawl)', async () => { - // Arrange - await renderAndWait(DataSourceProvider.waterCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - vi.mocked(fetchDataSources).mockClear() - - // Act - fireEvent.change(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.watercrawl.configWatercrawl')).not.toBeInTheDocument() - }) - }) - - it('should re-fetch sources after saving configuration (Jina Reader)', async () => { - // Arrange - await renderAndWait(DataSourceProvider.jinaReader) - fireEvent.click(screen.getByText('common.dataSource.configure')) - vi.mocked(fetchDataSources).mockClear() - - // Act - fireEvent.change(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder'), { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.jinaReader.configJinaReader')).not.toBeInTheDocument() - }) - }) - }) - - describe('Management Actions', () => { - it('should handle successful data source removal with toast notification', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[0]] } as DataSourcesResponse) - vi.mocked(removeDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' } as CommonResponse) - await renderAndWait(DataSourceProvider.fireCrawl) - await waitFor(() => expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument()) - - // Act - const removeBtn = screen.getByText('Firecrawl').parentElement?.querySelector('svg')?.parentElement - if (removeBtn) - fireEvent.click(removeBtn) - - // Assert - await waitFor(() => { - expect(removeDataSourceApiKeyBinding).toHaveBeenCalledWith('1') - expect(screen.getByText('common.api.remove')).toBeInTheDocument() - }) - expect(screen.queryByText('common.dataSource.website.configuredCrawlers')).not.toBeInTheDocument() - }) - - it('should skip removal API call if no data source ID is present', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - - // Act - const removeBtn = screen.queryByText('Firecrawl')?.parentElement?.querySelector('svg')?.parentElement - if (removeBtn) - fireEvent.click(removeBtn) - - // Assert - expect(removeDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Firecrawl Save Flow', () => { - it('should re-fetch sources after saving Firecrawl configuration', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - vi.mocked(fetchDataSources).mockClear() - - // Act - fill in required API key field and save - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder') - fireEvent.change(apiKeyInput, { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.firecrawl.configFirecrawl')).not.toBeInTheDocument() - }) - }) - }) - - describe('Cancel Flow', () => { - it('should close watercrawl modal when cancel is clicked', async () => { - // Arrange - await renderAndWait(DataSourceProvider.waterCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.watercrawl.configWatercrawl')).toBeInTheDocument() - - // Act - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - modal closed - await waitFor(() => { - expect(screen.queryByText('datasetCreation.watercrawl.configWatercrawl')).not.toBeInTheDocument() - }) - }) - - it('should close jina reader modal when cancel is clicked', async () => { - // Arrange - await renderAndWait(DataSourceProvider.jinaReader) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.jinaReader.configJinaReader')).toBeInTheDocument() - - // Act - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - modal closed - await waitFor(() => { - expect(screen.queryByText('datasetCreation.jinaReader.configJinaReader')).not.toBeInTheDocument() - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx deleted file mode 100644 index d7f15236a7..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx +++ /dev/null @@ -1,165 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { FirecrawlConfig } from '@/models/common' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'firecrawl' - -const DEFAULT_BASE_URL = 'https://api.firecrawl.dev' - -const ConfigFirecrawlModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [config, setConfig] = useState({ - api_key: '', - base_url: '', - }) - - const handleConfigChange = useCallback((key: string) => { - return (value: string | number) => { - setConfig(prev => ({ ...prev, [key]: value as string })) - } - }, []) - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (config.base_url && !((config.base_url.startsWith('http://') || config.base_url.startsWith('https://')))) - errorMsg = t('errorMsg.urlError', { ns: 'common' }) - if (!errorMsg) { - if (!config.api_key) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: 'firecrawl', - credentials: { - auth_type: 'bearer', - config: { - api_key: config.api_key, - base_url: config.base_url || DEFAULT_BASE_URL, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [config.api_key, config.base_url, onSaved, t, isSaving]) - - return ( - - -
    -
    -
    -
    -
    {t(`${I18N_PREFIX}.configFirecrawl`, { ns: 'datasetCreation' })}
    -
    - -
    - - -
    -
    - - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
    - - -
    - -
    -
    -
    -
    - - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
    -
    -
    -
    -
    -
    - ) -} -export default React.memo(ConfigFirecrawlModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx deleted file mode 100644 index 2374ae6174..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx +++ /dev/null @@ -1,144 +0,0 @@ -'use client' -import type { FC } from 'react' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { DataSourceProvider } from '@/models/common' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'jinaReader' - -const ConfigJinaReaderModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [apiKey, setApiKey] = useState('') - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (!errorMsg) { - if (!apiKey) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: DataSourceProvider.jinaReader, - credentials: { - auth_type: 'bearer', - config: { - api_key: apiKey, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [apiKey, onSaved, t, isSaving]) - - return ( - - -
    -
    -
    -
    -
    {t(`${I18N_PREFIX}.configJinaReader`, { ns: 'datasetCreation' })}
    -
    - -
    - setApiKey(value as string)} - placeholder={t(`${I18N_PREFIX}.apiKeyPlaceholder`, { ns: 'datasetCreation' })!} - /> -
    -
    - - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
    - - -
    - -
    -
    -
    -
    - - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
    -
    -
    -
    -
    -
    - ) -} -export default React.memo(ConfigJinaReaderModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx deleted file mode 100644 index a9399f25cd..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx +++ /dev/null @@ -1,165 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { WatercrawlConfig } from '@/models/common' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'watercrawl' - -const DEFAULT_BASE_URL = 'https://app.watercrawl.dev' - -const ConfigWatercrawlModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [config, setConfig] = useState({ - api_key: '', - base_url: '', - }) - - const handleConfigChange = useCallback((key: string) => { - return (value: string | number) => { - setConfig(prev => ({ ...prev, [key]: value as string })) - } - }, []) - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (config.base_url && !((config.base_url.startsWith('http://') || config.base_url.startsWith('https://')))) - errorMsg = t('errorMsg.urlError', { ns: 'common' }) - if (!errorMsg) { - if (!config.api_key) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: 'watercrawl', - credentials: { - auth_type: 'x-api-key', - config: { - api_key: config.api_key, - base_url: config.base_url || DEFAULT_BASE_URL, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [config.api_key, config.base_url, onSaved, t, isSaving]) - - return ( - - -
    -
    -
    -
    -
    {t(`${I18N_PREFIX}.configWatercrawl`, { ns: 'datasetCreation' })}
    -
    - -
    - - -
    -
    - - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
    - - -
    - -
    -
    -
    -
    - - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
    -
    -
    -
    -
    -
    - ) -} -export default React.memo(ConfigWatercrawlModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx deleted file mode 100644 index 22bfb4950e..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx +++ /dev/null @@ -1,137 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { DataSourceItem } from '@/models/common' -import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' -import s from '@/app/components/datasets/create/website/index.module.css' -import { useAppContext } from '@/context/app-context' -import { DataSourceProvider } from '@/models/common' -import { fetchDataSources, removeDataSourceApiKeyBinding } from '@/service/datasets' -import { cn } from '@/utils/classnames' -import Panel from '../panel' - -import { DataSourceType } from '../panel/types' -import ConfigFirecrawlModal from './config-firecrawl-modal' -import ConfigJinaReaderModal from './config-jina-reader-modal' -import ConfigWatercrawlModal from './config-watercrawl-modal' - -type Props = { - provider: DataSourceProvider -} - -const DataSourceWebsite: FC = ({ provider }) => { - const { t } = useTranslation() - const { isCurrentWorkspaceManager } = useAppContext() - const [sources, setSources] = useState([]) - const checkSetApiKey = useCallback(async () => { - const res = await fetchDataSources() as any - const list = res.sources - setSources(list) - }, []) - - useEffect(() => { - checkSetApiKey() - }, []) - - const [configTarget, setConfigTarget] = useState(null) - const showConfig = useCallback((provider: DataSourceProvider) => { - setConfigTarget(provider) - }, [setConfigTarget]) - - const hideConfig = useCallback(() => { - setConfigTarget(null) - }, [setConfigTarget]) - - const handleAdded = useCallback(() => { - checkSetApiKey() - hideConfig() - }, [checkSetApiKey, hideConfig]) - - const getIdByProvider = (provider: DataSourceProvider): string | undefined => { - const source = sources.find(item => item.provider === provider) - return source?.id - } - - const getProviderName = (provider: DataSourceProvider): string => { - if (provider === DataSourceProvider.fireCrawl) - return 'Firecrawl' - - if (provider === DataSourceProvider.waterCrawl) - return 'WaterCrawl' - - return 'Jina Reader' - } - - const handleRemove = useCallback((provider: DataSourceProvider) => { - return async () => { - const dataSourceId = getIdByProvider(provider) - if (dataSourceId) { - await removeDataSourceApiKeyBinding(dataSourceId) - setSources(sources.filter(item => item.provider !== provider)) - Toast.notify({ - type: 'success', - message: t('api.remove', { ns: 'common' }), - }) - } - } - }, [sources, t]) - - return ( - <> - item.provider === provider) !== undefined} - onConfigure={() => showConfig(provider)} - readOnly={!isCurrentWorkspaceManager} - configuredList={sources.filter(item => item.provider === provider).map(item => ({ - id: item.id, - logo: ({ className }: { className: string }) => { - if (item.provider === DataSourceProvider.fireCrawl) { - return ( -
    - 🔥 -
    - ) - } - - if (item.provider === DataSourceProvider.waterCrawl) { - return ( -
    - -
    - ) - } - return ( -
    - -
    - ) - }, - name: getProviderName(item.provider), - isActive: true, - }))} - onRemove={handleRemove(provider)} - /> - {configTarget === DataSourceProvider.fireCrawl && ( - - )} - {configTarget === DataSourceProvider.waterCrawl && ( - - )} - {configTarget === DataSourceProvider.jinaReader && ( - - )} - - - ) -} -export default React.memo(DataSourceWebsite) diff --git a/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx b/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx deleted file mode 100644 index 4ad49a8f8f..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx +++ /dev/null @@ -1,213 +0,0 @@ -import type { ConfigItemType } from '../config-item' -import { fireEvent, render, screen } from '@testing-library/react' -import ConfigItem from '../config-item' -import { DataSourceType } from '../types' - -/** - * ConfigItem Component Tests - * Tests rendering of individual configuration items for Notion and Website data sources. - */ - -// Mock Operate component to isolate ConfigItem unit tests. -vi.mock('../../data-source-notion/operate', () => ({ - default: ({ onAuthAgain, payload }: { onAuthAgain: () => void, payload: { id: string, total: number } }) => ( -
    - - {JSON.stringify(payload)} -
    - ), -})) - -describe('ConfigItem Component', () => { - const mockOnRemove = vi.fn() - const mockOnChangeAuthorizedPage = vi.fn() - const MockLogo = (props: React.SVGProps) => - - const baseNotionPayload: ConfigItemType = { - id: 'notion-1', - logo: MockLogo, - name: 'Notion Workspace', - isActive: true, - notionConfig: { total: 5 }, - } - - const baseWebsitePayload: ConfigItemType = { - id: 'website-1', - logo: MockLogo, - name: 'My Website', - isActive: true, - } - - afterEach(() => { - vi.clearAllMocks() - }) - - describe('Notion Configuration', () => { - it('should render active Notion config item with connected status and operator', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByTestId('mock-logo')).toBeInTheDocument() - expect(screen.getByText('Notion Workspace')).toBeInTheDocument() - const statusText = screen.getByText('common.dataSource.notion.connected') - expect(statusText).toHaveClass('text-util-colors-green-green-600') - expect(screen.getByTestId('operate-payload')).toHaveTextContent(JSON.stringify({ id: 'notion-1', total: 5 })) - }) - - it('should render inactive Notion config item with disconnected status', () => { - // Arrange - const inactivePayload = { ...baseNotionPayload, isActive: false } - - // Act - render( - , - ) - - // Assert - const statusText = screen.getByText('common.dataSource.notion.disconnected') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should handle auth action through the Operate component', () => { - // Arrange - render( - , - ) - - // Act - fireEvent.click(screen.getByTestId('operate-auth-btn')) - - // Assert - expect(mockOnChangeAuthorizedPage).toHaveBeenCalled() - }) - - it('should fallback to 0 total if notionConfig is missing', () => { - // Arrange - const payloadNoConfig = { ...baseNotionPayload, notionConfig: undefined } - - // Act - render( - , - ) - - // Assert - expect(screen.getByTestId('operate-payload')).toHaveTextContent(JSON.stringify({ id: 'notion-1', total: 0 })) - }) - - it('should handle missing notionActions safely without crashing', () => { - // Arrange - render( - , - ) - - // Act & Assert - expect(() => fireEvent.click(screen.getByTestId('operate-auth-btn'))).not.toThrow() - }) - }) - - describe('Website Configuration', () => { - it('should render active Website config item and hide operator', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.website.active')).toBeInTheDocument() - expect(screen.queryByTestId('mock-operate')).not.toBeInTheDocument() - }) - - it('should render inactive Website config item', () => { - // Arrange - const inactivePayload = { ...baseWebsitePayload, isActive: false } - - // Act - render( - , - ) - - // Assert - const statusText = screen.getByText('common.dataSource.website.inactive') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should show remove button and trigger onRemove when clicked (not read-only)', () => { - // Arrange - const { container } = render( - , - ) - - // Note: This selector is brittle but necessary since the delete button lacks - // accessible attributes (data-testid, aria-label). Ideally, the component should - // be updated to include proper accessibility attributes. - const deleteBtn = container.querySelector('div[class*="cursor-pointer"]') as HTMLElement - - // Act - fireEvent.click(deleteBtn) - - // Assert - expect(mockOnRemove).toHaveBeenCalled() - }) - - it('should hide remove button in read-only mode', () => { - // Arrange - const { container } = render( - , - ) - - // Assert - const deleteBtn = container.querySelector('div[class*="cursor-pointer"]') - expect(deleteBtn).not.toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx deleted file mode 100644 index d83cdb5360..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx +++ /dev/null @@ -1,226 +0,0 @@ -import type { ConfigItemType } from '../config-item' -import { fireEvent, render, screen } from '@testing-library/react' -import { DataSourceProvider } from '@/models/common' -import Panel from '../index' -import { DataSourceType } from '../types' - -/** - * Panel Component Tests - * Tests layout, conditional rendering, and interactions for data source panels (Notion and Website). - */ - -vi.mock('../../data-source-notion/operate', () => ({ - default: () =>
    , -})) - -describe('Panel Component', () => { - const onConfigure = vi.fn() - const onRemove = vi.fn() - const mockConfiguredList: ConfigItemType[] = [ - { id: '1', name: 'Item 1', isActive: true, logo: () => null }, - { id: '2', name: 'Item 2', isActive: false, logo: () => null }, - ] - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Notion Panel Rendering', () => { - it('should render Notion panel when not configured and isSupportList is true', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.notion.title')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.description')).toBeInTheDocument() - const connectBtn = screen.getByText('common.dataSource.connect') - expect(connectBtn).toBeInTheDocument() - - // Act - fireEvent.click(connectBtn) - // Assert - expect(onConfigure).toHaveBeenCalled() - }) - - it('should render Notion panel in readOnly mode when not configured', () => { - // Act - render( - , - ) - - // Assert - const connectBtn = screen.getByText('common.dataSource.connect') - expect(connectBtn).toHaveClass('cursor-default opacity-50 grayscale') - }) - - it('should render Notion panel when configured with list of items', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByRole('button', { name: 'common.dataSource.configure' })).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - expect(screen.getByText('Item 1')).toBeInTheDocument() - expect(screen.getByText('Item 2')).toBeInTheDocument() - }) - - it('should hide connect button for Notion if isSupportList is false', () => { - // Act - render( - , - ) - - // Assert - expect(screen.queryByText('common.dataSource.connect')).not.toBeInTheDocument() - }) - - it('should disable Notion configure button in readOnly mode (configured state)', () => { - // Act - render( - , - ) - - // Assert - const btn = screen.getByRole('button', { name: 'common.dataSource.configure' }) - expect(btn).toBeDisabled() - }) - }) - - describe('Website Panel Rendering', () => { - it('should show correct provider names and handle configuration when not configured', () => { - // Arrange - const { rerender } = render( - , - ) - - // Assert Firecrawl - expect(screen.getByText('🔥 Firecrawl')).toBeInTheDocument() - - // Rerender for WaterCrawl - rerender( - , - ) - expect(screen.getByText('WaterCrawl')).toBeInTheDocument() - - // Rerender for Jina Reader - rerender( - , - ) - expect(screen.getByText('Jina Reader')).toBeInTheDocument() - - // Act - const configBtn = screen.getByText('common.dataSource.configure') - fireEvent.click(configBtn) - // Assert - expect(onConfigure).toHaveBeenCalled() - }) - - it('should handle readOnly mode for Website configuration button', () => { - // Act - render( - , - ) - - // Assert - const configBtn = screen.getByText('common.dataSource.configure') - expect(configBtn).toHaveClass('cursor-default opacity-50 grayscale') - - // Act - fireEvent.click(configBtn) - // Assert - expect(onConfigure).not.toHaveBeenCalled() - }) - - it('should render Website panel correctly when configured with crawlers', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument() - expect(screen.getByText('Item 1')).toBeInTheDocument() - expect(screen.getByText('Item 2')).toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx deleted file mode 100644 index f62c5e147d..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx +++ /dev/null @@ -1,85 +0,0 @@ -'use client' -import type { FC } from 'react' -import { - RiDeleteBinLine, -} from '@remixicon/react' -import { noop } from 'es-toolkit/function' -import * as React from 'react' -import { useTranslation } from 'react-i18next' -import { cn } from '@/utils/classnames' -import Indicator from '../../../indicator' -import Operate from '../data-source-notion/operate' -import s from './style.module.css' -import { DataSourceType } from './types' - -export type ConfigItemType = { - id: string - logo: any - name: string - isActive: boolean - notionConfig?: { - total: number - } -} - -type Props = { - type: DataSourceType - payload: ConfigItemType - onRemove: () => void - notionActions?: { - onChangeAuthorizedPage: () => void - } - readOnly: boolean -} - -const ConfigItem: FC = ({ - type, - payload, - onRemove, - notionActions, - readOnly, -}) => { - const { t } = useTranslation() - const isNotion = type === DataSourceType.notion - const isWebsite = type === DataSourceType.website - const onChangeAuthorizedPage = notionActions?.onChangeAuthorizedPage || noop - - return ( -
    - -
    {payload.name}
    - { - payload.isActive - ? - : - } -
    - { - payload.isActive - ? t(isNotion ? 'dataSource.notion.connected' : 'dataSource.website.active', { ns: 'common' }) - : t(isNotion ? 'dataSource.notion.disconnected' : 'dataSource.website.inactive', { ns: 'common' }) - } -
    -
    - {isNotion && ( - - )} - - { - isWebsite && !readOnly && ( -
    - -
    - ) - } - -
    - ) -} -export default React.memo(ConfigItem) diff --git a/web/app/components/header/account-setting/data-source-page/panel/index.tsx b/web/app/components/header/account-setting/data-source-page/panel/index.tsx deleted file mode 100644 index 0909603ae8..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/index.tsx +++ /dev/null @@ -1,151 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { ConfigItemType } from './config-item' -import { RiAddLine } from '@remixicon/react' -import * as React from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' - -import { DataSourceProvider } from '@/models/common' -import { cn } from '@/utils/classnames' -import ConfigItem from './config-item' -import s from './style.module.css' -import { DataSourceType } from './types' - -type Props = { - type: DataSourceType - provider?: DataSourceProvider - isConfigured: boolean - onConfigure: () => void - readOnly: boolean - isSupportList?: boolean - configuredList: ConfigItemType[] - onRemove: () => void - notionActions?: { - onChangeAuthorizedPage: () => void - } -} - -const Panel: FC = ({ - type, - provider, - isConfigured, - onConfigure, - readOnly, - configuredList, - isSupportList, - onRemove, - notionActions, -}) => { - const { t } = useTranslation() - const isNotion = type === DataSourceType.notion - const isWebsite = type === DataSourceType.website - - const getProviderName = (): string => { - if (provider === DataSourceProvider.fireCrawl) - return '🔥 Firecrawl' - if (provider === DataSourceProvider.waterCrawl) - return 'WaterCrawl' - return 'Jina Reader' - } - - return ( -
    -
    -
    -
    -
    -
    {t(`dataSource.${type}.title`, { ns: 'common' })}
    - {isWebsite && ( -
    - {t('dataSource.website.with', { ns: 'common' })} - {' '} - {getProviderName()} -
    - )} -
    - { - !isConfigured && ( -
    - {t(`dataSource.${type}.description`, { ns: 'common' })} -
    - ) - } -
    - {isNotion && ( - <> - { - isConfigured - ? ( - - ) - : ( - <> - {isSupportList && ( -
    - - {t('dataSource.connect', { ns: 'common' })} -
    - )} - - ) - } - - )} - - {isWebsite && !isConfigured && ( -
    - {t('dataSource.configure', { ns: 'common' })} -
    - )} - -
    - { - isConfigured && ( - <> -
    -
    - {isNotion ? t('dataSource.notion.connectedWorkspace', { ns: 'common' }) : t('dataSource.website.configuredCrawlers', { ns: 'common' })} -
    -
    -
    -
    - { - configuredList.map(item => ( - - )) - } -
    - - ) - } -
    - ) -} -export default React.memo(Panel) diff --git a/web/app/components/header/account-setting/data-source-page/panel/style.module.css b/web/app/components/header/account-setting/data-source-page/panel/style.module.css deleted file mode 100644 index ac9be02205..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/style.module.css +++ /dev/null @@ -1,17 +0,0 @@ -.notion-icon { - background: #ffffff url(../../../assets/notion.svg) center center no-repeat; - background-size: 20px 20px; -} - -.website-icon { - background: #ffffff url(../../../../datasets/create/assets/web.svg) center center no-repeat; - background-size: 20px 20px; -} - -.workspace-item { - box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); -} - -.workspace-item:last-of-type { - margin-bottom: 0; -} diff --git a/web/app/components/header/account-setting/data-source-page/panel/types.ts b/web/app/components/header/account-setting/data-source-page/panel/types.ts deleted file mode 100644 index 345bc10f81..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/types.ts +++ /dev/null @@ -1,4 +0,0 @@ -export enum DataSourceType { - notion = 'notion', - website = 'website', -} diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 7e77af2e5f..bfceaeb059 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -1,8 +1,9 @@ 'use client' import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' -import { useCallback, useEffect, useRef, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import SearchInput from '@/app/components/base/search-input' +import { ScrollArea } from '@/app/components/base/ui/scroll-area' import BillingPage from '@/app/components/billing/billing-page' import CustomPage from '@/app/components/custom/custom-page' import { @@ -129,20 +130,6 @@ export default function AccountSetting({ ], }, ] - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - useEffect(() => { - const targetElement = scrollRef.current - const scrollHandle = (e: Event) => { - const userScrolled = (e.target as HTMLDivElement).scrollTop > 0 - setScrolled(userScrolled) - } - targetElement?.addEventListener('scroll', scrollHandle) - return () => { - targetElement?.removeEventListener('scroll', scrollHandle) - } - }, []) - const activeItem = [...menuItems[0].items, ...menuItems[1].items].find(item => item.key === activeMenu) const [searchValue, setSearchValue] = useState('') @@ -201,7 +188,7 @@ export default function AccountSetting({ }
    -
    +
    -
    -
    + +
    {activeItem?.name} {activeItem?.description && ( @@ -241,7 +234,7 @@ export default function AccountSetting({ {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && } {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && }
    -
    +
    diff --git a/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx b/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx index fb032ebd62..eafd57ed66 100644 --- a/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx @@ -61,7 +61,7 @@ vi.mock('@/app/components/base/select', async () => { } }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ refresh: mockRefresh }), })) diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index 5751e88285..6c84a25428 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -2,7 +2,6 @@ import type { Item } from '@/app/components/base/select' import type { Locale } from '@/i18n-config' -import { useRouter } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -12,6 +11,7 @@ import { useAppContext } from '@/context/app-context' import { useLocale } from '@/context/i18n' import { setLocaleOnClient } from '@/i18n-config' import { languages } from '@/i18n-config/language' +import { useRouter } from '@/next/navigation' import { updateUserProfile } from '@/service/common' import { timezones } from '@/utils/timezone' diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/dialog.spec.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/__tests__/dialog.spec.tsx similarity index 97% rename from web/app/components/header/account-setting/members-page/edit-workspace-modal/dialog.spec.tsx rename to web/app/components/header/account-setting/members-page/edit-workspace-modal/__tests__/dialog.spec.tsx index f489d64912..1714c3ceee 100644 --- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/dialog.spec.tsx +++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/__tests__/dialog.spec.tsx @@ -2,7 +2,7 @@ import type { ReactNode } from 'react' import { render } from '@testing-library/react' import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' -import EditWorkspaceModal from './index' +import EditWorkspaceModal from '../index' type DialogProps = { children: ReactNode diff --git a/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx b/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx index d2aeca1b6c..7de1fbeccb 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx @@ -2,11 +2,15 @@ import type { InvitationResponse } from '@/models/common' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import InviteModal from '../index' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + vi.mock('@/context/provider-context', () => ({ useProviderContextSelector: vi.fn(), useProviderContext: vi.fn(() => ({ @@ -14,6 +18,11 @@ vi.mock('@/context/provider-context', () => ({ })), })) vi.mock('@/service/common') +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + }, +})) vi.mock('@/context/i18n', () => ({ useLocale: () => 'en-US', })) @@ -37,7 +46,6 @@ describe('InviteModal', () => { const mockOnCancel = vi.fn() const mockOnSend = vi.fn() const mockRefreshLicenseLimit = vi.fn() - const mockNotify = vi.fn() beforeEach(() => { vi.clearAllMocks() @@ -49,10 +57,11 @@ describe('InviteModal', () => { }) const renderModal = (isEmailSetup = true) => render( - - - , + , ) + const fillEmails = (value: string) => { + fireEvent.change(screen.getByTestId('mock-email-input'), { target: { value } }) + } it('should render invite modal content', async () => { renderModal() @@ -68,12 +77,8 @@ describe('InviteModal', () => { }) it('should enable send button after entering an email', async () => { - const user = userEvent.setup() - renderModal() - - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByRole('button', { name: /members\.sendInvite/i })).toBeEnabled() }) @@ -84,7 +89,7 @@ describe('InviteModal', () => { renderModal() - await user.type(screen.getByTestId('mock-email-input'), 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -103,8 +108,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -116,8 +120,6 @@ describe('InviteModal', () => { }) it('should keep send button disabled when license limit is exceeded', async () => { - const user = userEvent.setup() - vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, refreshLicenseLimit: mockRefreshLicenseLimit, @@ -125,8 +127,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByRole('button', { name: /members\.sendInvite/i })).toBeDisabled() }) @@ -144,15 +145,11 @@ describe('InviteModal', () => { const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') // Use an email that passes basic validation but fails our strict regex (needs 2+ char TLD) - await user.type(input, 'invalid@email.c') + fillEmails('invalid@email.c') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'common.members.emailInvalid', - }) + expect(toast.error).toHaveBeenCalledWith('common.members.emailInvalid') expect(inviteMember).not.toHaveBeenCalled() }) @@ -160,8 +157,7 @@ describe('InviteModal', () => { const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByText('user@example.com')).toBeInTheDocument() @@ -203,7 +199,7 @@ describe('InviteModal', () => { renderModal() - await user.type(screen.getByTestId('mock-email-input'), 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -214,8 +210,6 @@ describe('InviteModal', () => { }) it('should show destructive text color when used size exceeds limit', async () => { - const user = userEvent.setup() - vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, refreshLicenseLimit: mockRefreshLicenseLimit, @@ -223,8 +217,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // usedSize = 10 + 1 = 11 > limit 10 → destructive color const counter = screen.getByText('11') @@ -241,8 +234,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -264,8 +256,6 @@ describe('InviteModal', () => { }) it('should show destructive color and disable send button when limit is exactly met with one email', async () => { - const user = userEvent.setup() - // size=10, limit=10 - adding 1 email makes usedSize=11 > limit=10 vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, @@ -274,8 +264,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // isLimitExceeded=true → button is disabled, cannot submit const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -293,8 +282,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -320,11 +308,9 @@ describe('InviteModal', () => { refreshLicenseLimit: mockRefreshLicenseLimit, } as unknown as Parameters[0])) - const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // isLimited=false → no destructive color const counter = screen.getByText('1') diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.module.css b/web/app/components/header/account-setting/members-page/invite-modal/index.module.css deleted file mode 100644 index fbaa1187bd..0000000000 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.module.css +++ /dev/null @@ -1,12 +0,0 @@ -.modal { - padding: 24px 32px !important; - width: 400px !important; -} - -.emailsInput { - background-color: rgb(243 244 246 / var(--tw-bg-opacity)) !important; -} - -.emailBackground { - background-color: white !important; -} diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index 8e4e47e0b8..9b4e9fccdc 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -2,20 +2,17 @@ import type { RoleKey } from './role-selector' import type { InvitationResult } from '@/models/common' import { useBoolean } from 'ahooks' -import { noop } from 'es-toolkit/function' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { ReactMultiEmail } from 'react-multi-email' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' -import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@/app/components/base/ui/dialog' +import { toast } from '@/app/components/base/ui/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import { cn } from '@/utils/classnames' -import s from './index.module.css' import RoleSelector from './role-selector' import 'react-multi-email/dist/style.css' @@ -34,7 +31,6 @@ const InviteModal = ({ const licenseLimit = useProviderContextSelector(s => s.licenseLimit) const refreshLicenseLimit = useProviderContextSelector(s => s.refreshLicenseLimit) const [emails, setEmails] = useState([]) - const { notify } = useContext(ToastContext) const [isLimited, setIsLimited] = useState(false) const [isLimitExceeded, setIsLimitExceeded] = useState(false) const [usedSize, setUsedSize] = useState(licenseLimit.workspace_members.size ?? 0) @@ -74,21 +70,28 @@ const InviteModal = ({ catch { } } else { - notify({ type: 'error', message: t('members.emailInvalid', { ns: 'common' }) }) + toast.error(t('members.emailInvalid', { ns: 'common' })) } setIsSubmitted() - }, [isLimitExceeded, emails, role, locale, onCancel, onSend, notify, t, isSubmitting, refreshLicenseLimit, setIsSubmitted, setIsSubmitting]) + }, [isLimitExceeded, emails, role, locale, onCancel, onSend, t, isSubmitting, refreshLicenseLimit, setIsSubmitted, setIsSubmitting]) return ( -
    - -
    -
    {t('members.inviteTeamMember', { ns: 'common' })}
    -
    + { + if (!open) + onCancel() + }} + > + + +
    + + {t('members.inviteTeamMember', { ns: 'common' })} +
    {t('members.inviteTeamMemberTip', { ns: 'common' })}
    {!isEmailSetup && ( @@ -152,8 +155,8 @@ const InviteModal = ({ {t('members.sendInvite', { ns: 'common' })}
    - -
    + + ) } diff --git a/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx b/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx index e258884b0f..6383b203d9 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx @@ -1,11 +1,10 @@ import * as React from 'react' -import { useState } from 'react' import { useTranslation } from 'react-i18next' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -25,115 +24,111 @@ export type RoleSelectorProps = { const RoleSelector = ({ value, onChange }: RoleSelectorProps) => { const { t } = useTranslation() - const [open, setOpen] = useState(false) const { datasetOperatorEnabled } = useProviderContext() + const [open, setOpen] = React.useState(false) return ( - -
    - setOpen(v => !v)} - className="block" - > + +
    {t('members.invitedAsRole', { ns: 'common', role: t(roleI18nKeyMap[value], { ns: 'common' }) })}
    +
    + + +
    { + onChange('normal') + setOpen(false) + }} > -
    {t('members.invitedAsRole', { ns: 'common', role: t(roleI18nKeyMap[value], { ns: 'common' }) })}
    -
    -
    - - -
    -
    -
    { - onChange('normal') - setOpen(false) - }} - > -
    -
    {t('members.normal', { ns: 'common' })}
    -
    {t('members.normalTip', { ns: 'common' })}
    - {value === 'normal' && ( -
    - )} -
    -
    -
    { - onChange('editor') - setOpen(false) - }} - > -
    -
    {t('members.editor', { ns: 'common' })}
    -
    {t('members.editorTip', { ns: 'common' })}
    - {value === 'editor' && ( -
    - )} -
    -
    -
    { - onChange('admin') - setOpen(false) - }} - > -
    -
    {t('members.admin', { ns: 'common' })}
    -
    {t('members.adminTip', { ns: 'common' })}
    - {value === 'admin' && ( -
    - )} -
    -
    - {datasetOperatorEnabled && ( +
    +
    {t('members.normal', { ns: 'common' })}
    +
    {t('members.normalTip', { ns: 'common' })}
    + {value === 'normal' && (
    { - onChange('dataset_operator') - setOpen(false) - }} - > -
    -
    {t('members.datasetOperator', { ns: 'common' })}
    -
    {t('members.datasetOperatorTip', { ns: 'common' })}
    - {value === 'dataset_operator' && ( -
    - )} -
    -
    + data-testid="role-option-check" + className="i-custom-vender-line-general-check absolute left-0 top-0.5 h-4 w-4 text-text-accent" + /> )}
    - -
    - +
    { + onChange('editor') + setOpen(false) + }} + > +
    +
    {t('members.editor', { ns: 'common' })}
    +
    {t('members.editorTip', { ns: 'common' })}
    + {value === 'editor' && ( +
    + )} +
    +
    +
    { + onChange('admin') + setOpen(false) + }} + > +
    +
    {t('members.admin', { ns: 'common' })}
    +
    {t('members.adminTip', { ns: 'common' })}
    + {value === 'admin' && ( +
    + )} +
    +
    + {datasetOperatorEnabled && ( +
    { + onChange('dataset_operator') + setOpen(false) + }} + > +
    +
    {t('members.datasetOperator', { ns: 'common' })}
    +
    {t('members.datasetOperatorTip', { ns: 'common' })}
    + {value === 'dataset_operator' && ( +
    + )} +
    +
    + )} +
    + + ) } diff --git a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx index 389db4a42d..dbabb384a2 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx @@ -1,15 +1,10 @@ import type { InvitationResult } from '@/models/common' -import { XMarkIcon } from '@heroicons/react/24/outline' -import { CheckCircleIcon } from '@heroicons/react/24/solid' -import { RiQuestionLine } from '@remixicon/react' -import { noop } from 'es-toolkit/function' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import Modal from '@/app/components/base/modal' -import Tooltip from '@/app/components/base/tooltip' +import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@/app/components/base/ui/dialog' +import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { IS_CE_EDITION } from '@/config' -import s from './index.module.css' import InvitationLink from './invitation-link' export type SuccessInvitationResult = Extract @@ -29,8 +24,18 @@ const InvitedModal = ({ const failedInvitationResults = useMemo(() => invitationResults?.filter(item => item.status !== 'success') as FailedInvitationResult[], [invitationResults]) return ( -
    - + { + if (!open) + onCancel() + }} + > + +
    - +
    -
    -
    {t('members.invitationSent', { ns: 'common' })}
    + {t('members.invitationSent', { ns: 'common' })} {!IS_CE_EDITION && (
    {t('members.invitationSentTip', { ns: 'common' })}
    )} @@ -54,7 +58,7 @@ const InvitedModal = ({ !!successInvitationResults.length && ( <> -
    {t('members.invitationLink', { ns: 'common' })}
    +
    {t('members.invitationLink', { ns: 'common' })}
    {successInvitationResults.map(item => )} @@ -64,18 +68,23 @@ const InvitedModal = ({ !!failedInvitationResults.length && ( <> -
    {t('members.failedInvitationEmails', { ns: 'common' })}
    +
    {t('members.failedInvitationEmails', { ns: 'common' })}
    { failedInvitationResults.map(item => (
    - -
    - {item.email} - -
    + + + {item.email} +
    +
    + )} + /> + + {item.message} +
    ), @@ -97,8 +106,8 @@ const InvitedModal = ({ {t('members.ok', { ns: 'common' })}
    - -
    +
    +
    ) } diff --git a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx index 8f55660fd8..0c5874c4dc 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx @@ -4,7 +4,7 @@ import copy from 'copy-to-clipboard' import { t } from 'i18next' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' -import Tooltip from '@/app/components/base/tooltip' +import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import s from './index.module.css' type IInvitationLinkProps = { @@ -38,20 +38,28 @@ const InvitationLink = ({
    - -
    {value.url}
    + + {value.url}
    } + /> + + {isCopied ? t('copied', { ns: 'appApi' }) : t('copy', { ns: 'appApi' })} +
    - -
    -
    -
    -
    + + +
    +
    +
    + )} + /> + + {isCopied ? t('copied', { ns: 'appApi' }) : t('copy', { ns: 'appApi' })} +
    diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx index 35c4676d5f..e2b14b9078 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.tsx @@ -102,7 +102,7 @@ const Operation = ({
    - +
    { diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx index c4f614737a..6a2af9ffdb 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx @@ -1,6 +1,6 @@ import { noop } from 'es-toolkit/function' import * as React from 'react' -import { useState } from 'react' +import { useCallback, useState } from 'react' import { Trans, useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' @@ -36,18 +36,33 @@ const TransferOwnershipModal = ({ onClose, show }: Props) => { const [stepToken, setStepToken] = useState('') const [newOwner, setNewOwner] = useState('') const [isTransfer, setIsTransfer] = useState(false) + const timerIdRef = React.useRef(undefined) + + const retimeCountdown = useCallback((timerId?: number) => { + if (timerIdRef.current !== undefined) + window.clearInterval(timerIdRef.current) + + timerIdRef.current = timerId + }, []) + + React.useEffect(() => { + if (!show) + retimeCountdown() + + return retimeCountdown + }, [retimeCountdown, show]) const startCount = () => { setTime(60) - const timer = setInterval(() => { + retimeCountdown(window.setInterval(() => { setTime((prev) => { - if (prev <= 0) { - clearInterval(timer) + if (prev <= 1) { + retimeCountdown() return 0 } return prev - 1 }) - }, 1000) + }, 1000)) } const sendEmail = async () => { @@ -126,6 +141,7 @@ const TransferOwnershipModal = ({ onClose, show }: Props) => {
    = ({
    - +
    { return ({ children }: { children: ReactNode }) => ( diff --git a/web/app/components/header/account-setting/model-provider-page/derive-model-status.spec.ts b/web/app/components/header/account-setting/model-provider-page/__tests__/derive-model-status.spec.ts similarity index 95% rename from web/app/components/header/account-setting/model-provider-page/derive-model-status.spec.ts rename to web/app/components/header/account-setting/model-provider-page/__tests__/derive-model-status.spec.ts index 1b248e98f2..8ef80e5025 100644 --- a/web/app/components/header/account-setting/model-provider-page/derive-model-status.spec.ts +++ b/web/app/components/header/account-setting/model-provider-page/__tests__/derive-model-status.spec.ts @@ -1,11 +1,11 @@ -import type { Model, ModelItem, ModelProvider } from './declarations' -import type { CredentialPanelState } from './provider-added-card/use-credential-panel-state' +import type { Model, ModelItem, ModelProvider } from '../declarations' +import type { CredentialPanelState } from '../provider-added-card/use-credential-panel-state' import { ConfigurationMethodEnum, ModelStatusEnum, ModelTypeEnum, -} from './declarations' -import { deriveModelStatus } from './derive-model-status' +} from '../declarations' +import { deriveModelStatus } from '../derive-model-status' const createCredentialState = (overrides: Partial = {}): CredentialPanelState => ({ variant: 'credits-active', diff --git a/web/app/components/header/account-setting/model-provider-page/index.non-cloud.spec.tsx b/web/app/components/header/account-setting/model-provider-page/__tests__/index.non-cloud.spec.tsx similarity index 87% rename from web/app/components/header/account-setting/model-provider-page/index.non-cloud.spec.tsx rename to web/app/components/header/account-setting/model-provider-page/__tests__/index.non-cloud.spec.tsx index c543c74472..0fbed45fa6 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.non-cloud.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/__tests__/index.non-cloud.spec.tsx @@ -3,8 +3,8 @@ import { CurrentSystemQuotaTypeEnum, CustomConfigurationStatusEnum, QuotaUnitEnum, -} from './declarations' -import ModelProviderPage from './index' +} from '../declarations' +import ModelProviderPage from '../index' const mockQuotaConfig = { quota_type: CurrentSystemQuotaTypeEnum.free, @@ -42,23 +42,23 @@ vi.mock('@/context/provider-context', () => ({ }), })) -vi.mock('./hooks', () => ({ +vi.mock('../hooks', () => ({ useDefaultModel: () => ({ data: null, isLoading: false }), })) -vi.mock('./provider-added-card', () => ({ +vi.mock('../provider-added-card', () => ({ default: () =>
    , })) -vi.mock('./provider-added-card/quota-panel', () => ({ +vi.mock('../provider-added-card/quota-panel', () => ({ default: () =>
    , })) -vi.mock('./system-model-selector', () => ({ +vi.mock('../system-model-selector', () => ({ default: () =>
    , })) -vi.mock('./install-from-marketplace', () => ({ +vi.mock('../install-from-marketplace', () => ({ default: () =>
    , })) diff --git a/web/app/components/header/account-setting/model-provider-page/__tests__/install-from-marketplace.spec.tsx b/web/app/components/header/account-setting/model-provider-page/__tests__/install-from-marketplace.spec.tsx index 452068e61c..68a705e6c4 100644 --- a/web/app/components/header/account-setting/model-provider-page/__tests__/install-from-marketplace.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/__tests__/install-from-marketplace.spec.tsx @@ -7,7 +7,7 @@ import { useMarketplaceAllPlugins } from '../hooks' import InstallFromMarketplace from '../install-from-marketplace' // Mock dependencies -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: ({ children, href }: { children: React.ReactNode, href: string }) => {children}, })) diff --git a/web/app/components/header/account-setting/model-provider-page/supports-credits.spec.ts b/web/app/components/header/account-setting/model-provider-page/__tests__/supports-credits.spec.ts similarity index 88% rename from web/app/components/header/account-setting/model-provider-page/supports-credits.spec.ts rename to web/app/components/header/account-setting/model-provider-page/__tests__/supports-credits.spec.ts index ef2e79c79b..b8ed478a93 100644 --- a/web/app/components/header/account-setting/model-provider-page/supports-credits.spec.ts +++ b/web/app/components/header/account-setting/model-provider-page/__tests__/supports-credits.spec.ts @@ -1,6 +1,6 @@ -import type { ModelProvider } from './declarations' -import { CurrentSystemQuotaTypeEnum } from './declarations' -import { providerSupportsCredits } from './supports-credits' +import type { ModelProvider } from '../declarations' +import { CurrentSystemQuotaTypeEnum } from '../declarations' +import { providerSupportsCredits } from '../supports-credits' vi.mock('@/config', async (importOriginal) => { const actual = await importOriginal() diff --git a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx index ab712f27cc..289e8ce80e 100644 --- a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx +++ b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx @@ -3,13 +3,13 @@ import type { } from './declarations' import type { Plugin } from '@/app/components/plugins/types' import { useTheme } from 'next-themes' -import Link from 'next/link' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import Loading from '@/app/components/base/loading' import List from '@/app/components/plugins/marketplace/list' import ProviderCard from '@/app/components/plugins/provider-card' +import Link from '@/next/link' import { cn } from '@/utils/classnames' import { getMarketplaceUrl } from '@/utils/var' import { diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx index 4025e307f1..536de9bbdf 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx @@ -116,7 +116,7 @@ const AddCustomModel = ({ > {renderTrigger(open)} - +
    { @@ -136,7 +136,7 @@ const AddCustomModel = ({ modelName={model.model} />
    {model.model} @@ -148,7 +148,7 @@ const AddCustomModel = ({ { !notAllowCustomCredential && (
    { handleOpenModalForAddNewCustomModel() setOpen(false) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx index e2f859b09d..15101a6542 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx @@ -164,7 +164,7 @@ const Authorized = ({ > {renderTrigger(mergedIsOpen)} - +
    { popupTitle && ( -
    +
    {popupTitle}
    ) @@ -218,7 +218,7 @@ const Authorized = ({ } : undefined, )} - className="system-xs-medium flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only" + className="flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only system-xs-medium" > {t('modelProvider.auth.addModelCredential', { ns: 'common' })} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx index 52513e7aeb..dd1d8e6eb9 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx @@ -53,14 +53,14 @@ const CredentialSelector = ({ triggerPopupSameWidth > !disabled && setOpen(v => !v)}> -
    +
    { selectedCredential && (
    { !selectedCredential.addNewCredential && } -
    {selectedCredential.credential_name}
    +
    {selectedCredential.credential_name}
    { selectedCredential.from_enterprise && ( Enterprise @@ -71,13 +71,13 @@ const CredentialSelector = ({ } { !selectedCredential && ( -
    {t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
    +
    {t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
    ) }
    - +
    { @@ -98,7 +98,7 @@ const CredentialSelector = ({ { !notAllowAddNewCredential && (
    diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/dialog.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/__tests__/dialog.spec.tsx similarity index 97% rename from web/app/components/header/account-setting/model-provider-page/model-modal/dialog.spec.tsx rename to web/app/components/header/account-setting/model-provider-page/model-modal/__tests__/dialog.spec.tsx index 9c08560a25..21f8d554c4 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/dialog.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/__tests__/dialog.spec.tsx @@ -1,8 +1,8 @@ import type { ReactNode } from 'react' -import type { Credential, ModelProvider } from '../declarations' +import type { Credential, ModelProvider } from '../../declarations' import { act, render, screen } from '@testing-library/react' -import { ConfigurationMethodEnum, ModelModalModeEnum } from '../declarations' -import ModelModal from './index' +import { ConfigurationMethodEnum, ModelModalModeEnum } from '../../declarations' +import ModelModal from '../index' type DialogProps = { children: ReactNode @@ -27,7 +27,7 @@ vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({ default: () =>
    , })) -vi.mock('../model-auth', () => ({ +vi.mock('../../model-auth', () => ({ CredentialSelector: ({ credentials }: { credentials: Credential[] }) =>
    {`credentials:${credentials.length}`}
    , })) @@ -52,7 +52,7 @@ vi.mock('@/app/components/base/ui/alert-dialog', () => ({ AlertDialogTitle: ({ children }: { children: ReactNode }) =>
    {children}
    , })) -vi.mock('../model-auth/hooks', () => ({ +vi.mock('../../model-auth/hooks', () => ({ useCredentialData: () => ({ isLoading: false, credentialData: { @@ -87,7 +87,7 @@ vi.mock('@/hooks/use-i18n', () => ({ useRenderI18nObject: () => (value: Record) => value[mockLanguage] || value.en_US, })) -vi.mock('../hooks', () => ({ +vi.mock('../../hooks', () => ({ useLanguage: () => mockLanguage, })) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/derive-trigger-status.spec.ts b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/derive-trigger-status.spec.ts similarity index 93% rename from web/app/components/header/account-setting/model-provider-page/model-parameter-modal/derive-trigger-status.spec.ts rename to web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/derive-trigger-status.spec.ts index 828895d35a..3186199524 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/derive-trigger-status.spec.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/derive-trigger-status.spec.ts @@ -1,7 +1,7 @@ -import type { ModelItem, ModelProvider } from '../declarations' -import type { CredentialPanelState } from '../provider-added-card/use-credential-panel-state' -import { ModelStatusEnum } from '../declarations' -import { deriveTriggerStatus } from './derive-trigger-status' +import type { ModelItem, ModelProvider } from '../../declarations' +import type { CredentialPanelState } from '../../provider-added-card/use-credential-panel-state' +import { ModelStatusEnum } from '../../declarations' +import { deriveTriggerStatus } from '../derive-trigger-status' const baseCredentialState: CredentialPanelState = { variant: 'api-active', diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx index 496058bf9b..5c8d5e7489 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx @@ -1,7 +1,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import ModelParameterModal from '../index' -let isAPIKeySet = true let parameterRules: Array> | undefined = [ { name: 'temperature', @@ -40,7 +39,7 @@ let activeTextGenerationModelList: Array> = [ vi.mock('@/context/provider-context', () => ({ useProviderContext: () => ({ - isAPIKeySet, + isAPIKeySet: true, }), })) @@ -50,6 +49,7 @@ vi.mock('@/service/use-common', () => ({ data: parameterRules, }, isLoading: isRulesLoading, + isPending: isRulesLoading, }), })) @@ -62,12 +62,18 @@ vi.mock('../../hooks', () => ({ })) vi.mock('../parameter-item', () => ({ - default: ({ parameterRule, onChange, onSwitch }: { + default: ({ parameterRule, onChange, onSwitch, nodesOutputVars, availableNodes }: { parameterRule: { name: string, label: { en_US: string } } onChange: (v: number) => void onSwitch: (checked: boolean, val: unknown) => void + nodesOutputVars?: unknown[] + availableNodes?: unknown[] }) => ( -
    +
    {parameterRule.label.en_US} @@ -119,7 +125,6 @@ describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - isAPIKeySet = true isRulesLoading = false parameterRules = [ { @@ -233,6 +238,26 @@ describe('ModelParameterModal', () => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) + it('should pass nodesOutputVars and availableNodes to ParameterItem', () => { + const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }] + const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }] + + render( + , + ) + + fireEvent.click(screen.getByText('Open Settings')) + + const paramEl = screen.getByTestId('param-temperature') + expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true') + expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true') + }) + it('should support custom triggers, workflow mode, and missing default model values', async () => { render( ({ +vi.mock('../../hooks', () => ({ useLanguage: () => 'en_US', })) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/parameter-item.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/parameter-item.spec.tsx index 5aec68c098..0d684c35f4 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/parameter-item.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/parameter-item.spec.tsx @@ -1,5 +1,10 @@ import type { ModelParameterRule } from '../../declarations' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' import ParameterItem from '../parameter-item' vi.mock('../../hooks', () => ({ @@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({ ), })) +let promptEditorOnChange: ((text: string) => void) | undefined +let capturedWorkflowNodesMap: Record | undefined + +vi.mock('@/app/components/base/prompt-editor', () => ({ + default: ({ value, onChange, workflowVariableBlock }: { + value: string + onChange: (text: string) => void + workflowVariableBlock?: { + show: boolean + variables: NodeOutPutVar[] + workflowNodesMap?: Record + } + }) => { + promptEditorOnChange = onChange + capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap + return ( +
    + {value} +
    + ) + }, +})) + describe('ParameterItem', () => { const createRule = (overrides: Partial = {}): ModelParameterRule => ({ name: 'temp', @@ -30,9 +58,10 @@ describe('ParameterItem', () => { beforeEach(() => { vi.clearAllMocks() + promptEditorOnChange = undefined + capturedWorkflowNodesMap = undefined }) - // Float tests it('should render float controls and clamp numeric input to max', () => { const onChange = vi.fn() render() @@ -50,7 +79,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(0.1) }) - // Int tests it('should render int controls and clamp numeric input', () => { const onChange = vi.fn() render() @@ -75,22 +103,17 @@ describe('ParameterItem', () => { it('should render int input without slider if min or max is missing', () => { render() expect(screen.queryByRole('slider')).not.toBeInTheDocument() - // No max -> precision step expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0') }) - // Slider events (uses generic value mock for slider) it('should handle slide change and clamp values', () => { const onChange = vi.fn() render() - // Test that the actual slider triggers the onChange logic correctly - // The implementation of Slider uses onChange(val) directly via the mock fireEvent.click(screen.getByTestId('slider-btn')) expect(onChange).toHaveBeenCalledWith(2) }) - // Text & String tests it('should render exact string input and propagate text changes', () => { const onChange = vi.fn() render() @@ -109,21 +132,17 @@ describe('ParameterItem', () => { it('should render select for string with options', () => { render() - // Select renders the selected value in the trigger expect(screen.getByText('a')).toBeInTheDocument() }) - // Tag Tests it('should render tag input for tag type', () => { const onChange = vi.fn() render() expect(screen.getByText('placeholder')).toBeInTheDocument() - // Trigger mock tag input fireEvent.click(screen.getByTestId('tag-input')) expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2']) }) - // Boolean tests it('should render boolean radios and update value on click', () => { const onChange = vi.fn() render() @@ -131,7 +150,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(false) }) - // Switch tests it('should call onSwitch with current value when optional switch is toggled off', () => { const onSwitch = vi.fn() render() @@ -146,7 +164,6 @@ describe('ParameterItem', () => { expect(screen.queryByRole('switch')).not.toBeInTheDocument() }) - // Default Value Fallbacks (rendering without value) it('should use default values if value is undefined', () => { const { rerender } = render() expect(screen.getByRole('spinbutton')).toHaveValue(0.5) @@ -158,26 +175,102 @@ describe('ParameterItem', () => { expect(screen.getByText('True')).toBeInTheDocument() expect(screen.getByText('False')).toBeInTheDocument() - // Without default - rerender() // min is 0 by default in createRule + rerender() expect(screen.getByRole('spinbutton')).toHaveValue(0) }) - // Input Blur it('should reset input to actual bound value on blur', () => { render() const input = screen.getByRole('spinbutton') - // change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state) - // Actually our test fires a change so localValue = 1, then blur sets it fireEvent.change(input, { target: { value: '5' } }) fireEvent.blur(input) expect(input).toHaveValue(1) }) - // Unsupported it('should render no input for unsupported parameter type', () => { render() expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument() }) + + describe('workflow variable reference', () => { + const mockNodesOutputVars: NodeOutPutVar[] = [ + { nodeId: 'node1', title: 'LLM Node', vars: [] }, + ] + const mockAvailableNodes: Node[] = [ + { id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node, + { id: 'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node, + ] + + it('should build workflowNodesMap and render PromptEditor for string type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node') + expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start') + expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start) + + promptEditorOnChange?.('updated text') + expect(onChange).toHaveBeenCalledWith('updated text') + }) + + it('should build workflowNodesMap and render PromptEditor for text type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + + promptEditorOnChange?.('new long text') + expect(onChange).toHaveBeenCalledWith('new long text') + }) + + it('should fall back to plain input when not in workflow mode for string type', () => { + render( + , + ) + + expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('should return undefined workflowNodesMap when not in workflow mode', () => { + render( + , + ) + + expect(capturedWorkflowNodesMap).toBeUndefined() + }) + }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index fc10536de8..ccb2c67a0d 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -9,6 +9,10 @@ import type { } from '../declarations' import type { ParameterValue } from './parameter-item' import type { TriggerProps } from './trigger' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' @@ -32,7 +36,6 @@ import Trigger from './trigger' export type ModelParameterModalProps = { popupClassName?: string - portalToFollowElemContentClassName?: string isAdvancedMode: boolean modelId: string provider: string @@ -46,11 +49,12 @@ export type ModelParameterModalProps = { readonly?: boolean isInWorkflow?: boolean scope?: string + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } const ModelParameterModal: FC = ({ popupClassName, - portalToFollowElemContentClassName, isAdvancedMode, modelId, provider, @@ -63,11 +67,18 @@ const ModelParameterModal: FC = ({ renderTrigger, readonly, isInWorkflow, + nodesOutputVars, + availableNodes, }) => { const { t } = useTranslation() const [open, setOpen] = useState(false) const settingsIconRef = useRef(null) - const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId) + const { + data: parameterRulesData, + isPending, + isLoading, + } = useModelParameterRules(provider, modelId) + const isRulesLoading = isPending || isLoading const { currentProvider, currentModel, @@ -161,7 +172,6 @@ const ModelParameterModal: FC = ({ @@ -194,7 +204,7 @@ const ModelParameterModal: FC = ({ }
    { - isLoading + isRulesLoading ?
    : ( [ @@ -208,6 +218,8 @@ const ModelParameterModal: FC = ({ onChange={v => handleParamChange(parameter.name, v)} onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)} isInWorkflow={isInWorkflow} + nodesOutputVars={nodesOutputVars} + availableNodes={availableNodes} /> )) ) @@ -216,7 +228,7 @@ const ModelParameterModal: FC = ({ ) } { - !parameterRules.length && isLoading && ( + !parameterRules.length && isRulesLoading && (
    ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 86fb6d81d0..01e3f45371 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,11 +1,18 @@ import type { ModelParameterRule } from '../declarations' -import { useEffect, useRef, useState } from 'react' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' +import { useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import PromptEditor from '@/app/components/base/prompt-editor' import Radio from '@/app/components/base/radio' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import TagInput from '@/app/components/base/tag-input' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' +import { BlockEnum } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -18,18 +25,43 @@ type ParameterItemProps = { onChange?: (value: ParameterValue) => void onSwitch?: (checked: boolean, assignValue: ParameterValue) => void isInWorkflow?: boolean + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } + function ParameterItem({ parameterRule, value, onChange, onSwitch, isInWorkflow, + nodesOutputVars, + availableNodes = [], }: ParameterItemProps) { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) + const workflowNodesMap = useMemo(() => { + if (!isInWorkflow || !availableNodes.length) + return undefined + + return availableNodes.reduce>>((acc, node) => { + acc[node.id] = { + title: node.data.title, + type: node.data.type, + } + if (node.data.type === BlockEnum.Start) { + acc.sys = { + title: t('blocks.start', { ns: 'workflow' }), + type: BlockEnum.Start, + } + } + return acc + }, {}) + }, [availableNodes, isInWorkflow, t]) + const getDefaultValue = () => { let defaultValue: ParameterValue @@ -196,6 +228,25 @@ function ParameterItem({ } if (parameterRule.type === 'string' && !parameterRule.options?.length) { + if (isInWorkflow && nodesOutputVars) { + return ( +
    + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
    + ) + } + return ( + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
    + ) + } + return (