diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 0000000000..15c697730a --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: true + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bb7d06232..3f53811f85 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,7 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost +/api/graphon/model_runtime/ @laipz8200 @WH-2099 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 6f3b3c08b4..24af948732 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,9 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: web/.nvmrc + working-directory: web + node-version-file: .nvmrc cache: true - cache-dependency-path: web/pnpm-lock.yaml - run-install: | - cwd: ./web + run-install: true diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6b87946221..c1da73b5df 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -25,7 +25,6 @@ jobs: strategy: matrix: python-version: - - "3.11" - "3.12" steps: diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index be6186980e..d8a53c9594 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -2,6 +2,9 @@ name: autofix.ci on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: @@ -12,9 +15,15 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "autofix.ci updates pull request branches, not merge group refs." + + - if: github.event_name != 'merge_group' + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs + if: github.event_name != 'merge_group' id: docker-compose-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: @@ -24,30 +33,34 @@ jobs: docker/docker-compose-template.yaml docker/docker-compose.yaml - name: Check web inputs + if: github.event_name != 'merge_group' id: web-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** - name: Check api inputs + if: github.event_name != 'merge_group' id: api-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + - if: github.event_name != 'merge_group' + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + - if: github.event_name != 'merge_group' + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Generate Docker Compose - if: steps.docker-compose-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' run: | cd docker ./generate_docker_compose - - if: steps.api-changes.outputs.any_changed == 'true' + - if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api uv sync --dev @@ -59,13 +72,13 @@ jobs: uv run ruff format .. - name: count migration progress - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -95,13 +108,14 @@ jobs: find . -name "*.py.bak" -type f -delete - name: Setup web environment - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' uses: ./.github/actions/setup-web - name: ESLint autofix - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + - if: github.event_name != 'merge_group' + uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 69023c24cc..6fffbefce0 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -3,10 +3,14 @@ name: Main CI Pipeline on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: + actions: write contents: write pull-requests: write checks: write @@ -17,9 +21,24 @@ concurrency: cancel-in-progress: true jobs: + pre_job: + name: Skip Duplicate Checks + runs-on: ubuntu-latest + outputs: + should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} + steps: + - id: skip_check + continue-on-error: true + uses: fkirc/skip-duplicate-actions@f75f66ce1886f00957d99748a42c724f4330bdcf # v5.3.1 + with: + cancel_others: 'true' + concurrent_skipping: same_content_newer + # Check which paths were changed to determine which tests to run check-changes: name: Check Changed Files + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' runs-on: ubuntu-latest outputs: api-changed: ${{ steps.changes.outputs.api }} @@ -50,33 +69,247 @@ jobs: - 'api/migrations/**' - '.github/workflows/db-migration-test.yml' - # Run tests in parallel - api-tests: - name: API Tests - needs: check-changes - if: needs.check-changes.outputs.api-changed == 'true' + # Run tests in parallel while always emitting stable required checks. + api-tests-run: + name: Run API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml secrets: inherit - web-tests: - name: Web Tests - needs: check-changes - if: needs.check-changes.outputs.web-changed == 'true' + api-tests-skip: + name: Skip API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped API tests + run: echo "No API-related changes detected; skipping API tests." + + api-tests: + name: API Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - api-tests-run + - api-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize API Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.api-changed }} + RUN_RESULT: ${{ needs.api-tests-run.result }} + SKIP_RESULT: ${{ needs.api-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "API tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "API tests ran successfully." + exit 0 + fi + + echo "API tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "API tests were skipped because no API-related files changed." + exit 0 + fi + + echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-tests-run: + name: Run Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml secrets: inherit + web-tests-skip: + name: Skip Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web tests + run: echo "No web-related changes detected; skipping web tests." + + web-tests: + name: Web Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-tests-run + - web-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.web-changed }} + RUN_RESULT: ${{ needs.web-tests-run.result }} + SKIP_RESULT: ${{ needs.web-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web tests ran successfully." + exit 0 + fi + + echo "Web tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web tests were skipped because no web-related files changed." + exit 0 + fi + + echo "Web tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + style-check: name: Style Check + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' uses: ./.github/workflows/style.yml + vdb-tests-run: + name: Run VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed == 'true' + uses: ./.github/workflows/vdb-tests.yml + + vdb-tests-skip: + name: Skip VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped VDB tests + run: echo "No VDB-related changes detected; skipping VDB tests." + vdb-tests: name: VDB Tests - needs: check-changes - if: needs.check-changes.outputs.vdb-changed == 'true' - uses: ./.github/workflows/vdb-tests.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - vdb-tests-run + - vdb-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize VDB Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.vdb-changed }} + RUN_RESULT: ${{ needs.vdb-tests-run.result }} + SKIP_RESULT: ${{ needs.vdb-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "VDB tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "VDB tests ran successfully." + exit 0 + fi + + echo "VDB tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "VDB tests were skipped because no VDB-related files changed." + exit 0 + fi + + echo "VDB tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + db-migration-test-run: + name: Run DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed == 'true' + uses: ./.github/workflows/db-migration-test.yml + + db-migration-test-skip: + name: Skip DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped DB migration tests + run: echo "No migration-related changes detected; skipping DB migration tests." db-migration-test: name: DB Migration Test - needs: check-changes - if: needs.check-changes.outputs.migration-changed == 'true' - uses: ./.github/workflows/db-migration-test.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - db-migration-test-run + - db-migration-test-skip + runs-on: ubuntu-latest + steps: + - name: Finalize DB Migration Test status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.migration-changed }} + RUN_RESULT: ${{ needs.db-migration-test-run.result }} + SKIP_RESULT: ${{ needs.db-migration-test-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "DB migration tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "DB migration tests ran successfully." + exit 0 + fi + + echo "DB migration tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "DB migration tests were skipped because no migration-related files changed." + exit 0 + fi + + echo "DB migration tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index c21331ec0d..49d2e94695 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -7,6 +7,9 @@ on: - edited - reopened - synchronize + merge_group: + branches: ["main"] + types: [checks_requested] jobs: lint: @@ -15,7 +18,11 @@ jobs: pull-requests: read runs-on: ubuntu-latest steps: + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "Semantic PR title validation is handled on pull requests." - name: Check title + if: github.event_name == 'pull_request' uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 657a481f74..23ae36f7b1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -84,20 +84,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +114,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 84f8000a01..1869254295 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -120,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76 + uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f45f2137d6..7c4cd0ba8c 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -14,7 +14,6 @@ jobs: strategy: matrix: python-version: - - "3.11" - "3.12" steps: diff --git a/api/.env.example b/api/.env.example index 40e1c2dfdf..9672a99d55 100644 --- a/api/.env.example +++ b/api/.env.example @@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/.importlinter b/api/.importlinter index a836d09088..c2841f64d2 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,10 +1,14 @@ [importlinter] root_packages = core - dify_graph + constants + context + graphon configs controllers extensions + factories + libs models tasks services @@ -22,40 +26,30 @@ layers = runtime entities containers = - dify_graph + graphon ignore_imports = - dify_graph.nodes.base.node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events - dify_graph.nodes.loop.loop_node -> dify_graph.graph_events + graphon.nodes.base.node -> graphon.graph_events + graphon.nodes.iteration.iteration_node -> graphon.graph_events + graphon.nodes.loop.loop_node -> graphon.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine + graphon.nodes.iteration.iteration_node -> graphon.graph_engine + graphon.nodes.loop.loop_node -> graphon.graph_engine # TODO(QuantumGhost): fix the import violation later - dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities - -[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + graphon.entities.pause_reason -> graphon.nodes.human_input.entities [importlinter:contract:workflow-external-imports] name = Workflow External Imports type = forbidden source_modules = - dify_graph + graphon forbidden_modules = + constants configs + context controllers extensions + factories + libs models services tasks @@ -88,46 +82,14 @@ forbidden_modules = core.tools core.trigger core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids + +[importlinter:contract:workflow-third-party-imports] +name = Workflow Third-Party Imports +type = forbidden +source_modules = + graphon +forbidden_modules = + sqlalchemy [importlinter:contract:rsc] name = RSC @@ -136,7 +98,7 @@ layers = graph_engine response_coordinator containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:worker] name = Worker @@ -145,7 +107,7 @@ layers = graph_engine worker containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:graph-engine-architecture] name = Graph Engine Architecture @@ -161,28 +123,28 @@ layers = worker_management domain containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:domain-isolation] name = Domain Model Isolation type = forbidden source_modules = - dify_graph.graph_engine.domain + graphon.graph_engine.domain forbidden_modules = - dify_graph.graph_engine.worker_management - dify_graph.graph_engine.command_channels - dify_graph.graph_engine.layers - dify_graph.graph_engine.protocols + graphon.graph_engine.worker_management + graphon.graph_engine.command_channels + graphon.graph_engine.layers + graphon.graph_engine.protocols [importlinter:contract:worker-management] name = Worker Management type = forbidden source_modules = - dify_graph.graph_engine.worker_management + graphon.graph_engine.worker_management forbidden_modules = - dify_graph.graph_engine.orchestration - dify_graph.graph_engine.command_processing - dify_graph.graph_engine.event_management + graphon.graph_engine.orchestration + graphon.graph_engine.command_processing + graphon.graph_engine.event_management [importlinter:contract:graph-traversal-components] @@ -192,11 +154,11 @@ layers = edge_processor skip_propagator containers = - dify_graph.graph_engine.graph_traversal + graphon.graph_engine.graph_traversal [importlinter:contract:command-channels] name = Command Channels Independence type = independence modules = - dify_graph.graph_engine.command_channels.in_memory_channel - dify_graph.graph_engine.command_channels.redis_channel + graphon.graph_engine.command_channels.in_memory_channel + graphon.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index b0947eb619..4b1252a861 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] +"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/app_factory.py b/api/app_factory.py index 066eb2ae2c..76838f9925 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/commands/vector.py b/api/commands/vector.py index 4cf11c9ad1..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -85,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -177,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -269,7 +272,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b..831f0a49e0 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index f8447c6979..8a6a921a4e 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings): ENTERPRISE_REQUEST_TIMEOUT: int = Field( ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. + """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_OTLP_API_KEY: str = Field( + description="Bearer token for enterprise OTLP export authentication.", + default="", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + # Setting the default value to False to avoid accidentally log PII data in traces. + default=False, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ge=0.0, + le=1.0, + ) diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 8f956745b1..c8e4f7309f 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings): description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", default="COARSE_MODE", ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field( + description="Auto build row count increment threshold (default is 500)", + default=500, + ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field( + description="Auto build row count increment ratio threshold (default is 0.05)", + default=0.05, + ) + + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field( + description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)", + default=300, + ) diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d..8df37138e8 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. -This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `graphon` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). - - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. - """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f0..ba9a24d4f3 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `graphon` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. """ @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. - - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. - - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. - """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. """ if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b4..eddd6448d8 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 100% rename from api/dify_graph/context/models.py rename to api/context/models.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a57..764f9f8ee2 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index ff5326dade..515a6a5125 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -4,7 +4,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, computed_field -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 6c54be84a8..783cb5c444 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..357697ed30 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,9 +26,9 @@ from controllers.console.wraps import ( from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db +from graphon.enums import WorkflowExecutionStatus +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 2c5e8d29ee..91fbe4a85a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -22,7 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 4d7ddfea13..fe274e4c9a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -26,7 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 74750981dd..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -458,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..c720a5e074 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -18,8 +18,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b53..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 4fb73f61f3..dc752939ae 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -4,7 +4,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -24,9 +24,9 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required @@ -244,27 +244,25 @@ class ChatMessageListApi(Resource): def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) - .first() + .limit(1) ) if not conversation: raise NotFound("Conversation Not Exists.") if args.first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args.first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1) ) if not first_message: raise NotFound("First message not found") - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -272,16 +270,14 @@ class ChatMessageListApi(Resource): ) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() # Initialize has_more based on whether we have a full page if len(history_messages) == args.limit: @@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource): message_id = str(args.message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource): @login_required @account_initialization_required def get(self, app_model): - count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() + count = db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) + ) return {"count": count} @@ -479,7 +479,9 @@ class MessageApi(Resource): def get(self, app_model, message_id: str): message_id = str(message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..8bb5aa2c1b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict @@ -90,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -129,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index d59aa44718..2737dd1dfd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE @@ -29,15 +30,15 @@ from core.trigger.debug.event_selectors import ( create_event_poller, select_trigger_debug_events, ) -from dify_graph.enums import NodeType -from dify_graph.file.models import File -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from graphon.enums import NodeType +from graphon.file.models import File +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" @@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence mappings=files, tenant_id=workflow.tenant_id, config=file_extra_config, + access_controller=_file_access_controller, ) return file_objs diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 9b148c3f18..8cf0004b09 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -9,12 +9,12 @@ from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, build_workflow_archived_log_pagination_model, ) +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382..657b072490 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,14 +15,15 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.file import helpers as file_helpers -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -389,13 +391,21 @@ class VariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395e..d1df722729 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, cast +from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -12,8 +12,7 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -27,6 +26,8 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value @@ -172,6 +173,23 @@ console_ns.schema_model( ) +class HumanInputPauseTypeResponse(TypedDict): + type: Literal["human_input"] + form_id: str + backstage_input_url: str | None + + +class PausedNodeResponse(TypedDict): + node_id: str + node_title: str + pause_type: HumanInputPauseTypeResponse + + +class WorkflowPauseDetailsResponse(TypedDict): + paused_at: str | None + paused_nodes: list[PausedNodeResponse] + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @@ -489,18 +507,22 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - return { + empty_response: WorkflowPauseDetailsResponse = { "paused_at": None, "paused_nodes": [], - }, 200 + } + return empty_response, 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { + paused_nodes: list[PausedNodeResponse] = [] + response: WorkflowPauseDetailsResponse = { "paused_at": paused_at.isoformat() + "Z" if paused_at else None, "paused_nodes": paused_nodes, } @@ -514,7 +536,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index c2a95ddad2..9e7faa09c5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1ed931b0d7..844f3c91ff 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( @@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: @@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() # Create workspace if needed if ( diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 112e152432..5c7011fd22 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,9 +1,10 @@ import logging +import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -112,6 +113,9 @@ class OAuthCallback(Resource): error_text = e.response.text logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 + except ValueError as e: + logger.warning("OAuth error with %s", provider, exc_info=True) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}") if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService.get_invitation_by_token(token=invite_token) @@ -176,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6e59d4203c..665a80802d 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cd..5d704b6224 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -3,7 +3,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -25,12 +25,12 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( @@ -51,10 +51,11 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -331,7 +332,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -445,7 +446,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): @@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -777,7 +781,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -826,7 +833,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") @@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bc90c4ffbd..edb738aad8 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -10,7 +10,7 @@ import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field -from sqlalchemy import asc, desc, select +from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -27,8 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( @@ -38,6 +37,8 @@ from fields.document_fields import ( document_status_fields, document_with_segments_fields, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile @@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() ) if dataset_process_rule: mode = dataset_process_rule.mode @@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource): if fetch: for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) document.completed_segments = completed_segments document.total_segments = total_segments @@ -448,11 +450,11 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, @@ -462,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource): if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) + file_detail = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) if file_detail is None: @@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) - .count() + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) + ) + or 0 ) # Create a dictionary with document attributes and additional fields @@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - log = ( - db.session.query(DocumentPipelineExecutionLog) - .filter_by(document_id=document_id) + log = db.session.scalar( + select(DocumentPipelineExecutionLog) + .where(DocumentPipelineExecutionLog.document_id == document_id) .order_by(DocumentPipelineExecutionLog.created_at.desc()) - .first() + .limit(1) ) if not log: return { @@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b712..2fd84303d7 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,10 +26,11 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment @@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" from services.summary_index_service import SummaryIndexService - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore # Query summary for this segment (only enabled summaries) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None @@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource): # Add summary to each segment segments_with_summary = [] for segment in segments.items: - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore segment_dict["summary"] = summaries.get(segment.id) segments_with_summary.append(segment_dict) @@ -279,10 +280,10 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -333,9 +334,9 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -383,10 +384,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,19 +560,19 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 86090bcd10..fc6896f123 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService def _build_dataset_detail_model(): @@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel): class BedrockRetrievalPayload(BaseModel): - retrieval_setting: dict[str, object] + retrieval_setting: "BedrockRetrievalSetting" query: str knowledge_id: str diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cd568cf835..699fa599c8 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -19,8 +19,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a4498005d8..946fa599e6 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -10,8 +10,8 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5..977ae93c03 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,11 +21,12 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline @@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() def _create_pagination_parser(): @@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 3912cc73ca..9079fbc29a 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -37,9 +37,9 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 3ef1341abc..d533e6c5b1 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar +from sqlalchemy import select + from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]): del kwargs["pipeline_id"] - pipeline = ( - db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) - .first() + pipeline = db.session.scalar( + select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1) ) if not pipeline: diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ffb9e5bb6e..bc78ee6d2d 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index fcd52d2818..ccdccceaa6 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -24,8 +24,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 15e1aea361..a72cf6328a 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -21,9 +21,9 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index a8d8036f0f..26aa086aac 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -42,8 +42,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( @@ -61,6 +59,8 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7801cee473..17dbbdd534 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -21,9 +21,9 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 7207f7fd1d..e37e78c966 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator @@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() + generator: BaseAppGenerator if app.mode == AppMode.ADVANCED_CHAT: generator = AdvancedChatAppGenerator() elif app.mode == AppMode.WORKFLOW: @@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource): ) -def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): +def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App: query = select(App).where( App.id == workflow_run.app_id, App.tenant_id == workflow_run.tenant_id, diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49162d4dae..2a46d2250a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -13,9 +13,9 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index e2b504751b..764f488755 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -2,7 +2,7 @@ from flask_restx import Resource, fields from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 538c5fb561..f45b72f390 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -8,7 +8,7 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 0a9e54de99..2a6f37aec8 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -5,8 +5,8 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index db3b02ae94..b22b91706e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c..3c7b97d7fc 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ee537367c7..b3e344ccea 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -14,7 +14,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -200,7 +200,7 @@ class PluginDebuggingKeyApi(Resource): "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, } except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/list") @@ -215,7 +215,7 @@ class PluginListApi(Resource): try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) @@ -232,7 +232,7 @@ class PluginListLatestVersionsApi(Resource): try: versions = PluginService.list_latest_versions(args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"versions": versions}) @@ -251,7 +251,7 @@ class PluginListInstallationsFromIdsApi(Resource): try: plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins}) @@ -266,7 +266,7 @@ class PluginIconApi(Resource): try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -286,7 +286,7 @@ class PluginAssetApi(Resource): binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) return send_file(io.BytesIO(binary), mimetype="application/octet-stream") except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upload/pkg") @@ -303,7 +303,7 @@ class PluginUploadFromPkgApi(Resource): try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -323,7 +323,7 @@ class PluginUploadFromGithubApi(Resource): try: response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -361,7 +361,7 @@ class PluginInstallFromPkgApi(Resource): try: response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -387,7 +387,7 @@ class PluginInstallFromGithubApi(Resource): args.package, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -407,7 +407,7 @@ class PluginInstallFromMarketplaceApi(Resource): try: response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -433,7 +433,7 @@ class PluginFetchMarketplacePkgApi(Resource): } ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/fetch-manifest") @@ -453,7 +453,7 @@ class PluginFetchManifestApi(Resource): {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()} ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks") @@ -471,7 +471,7 @@ class PluginFetchInstallTasksApi(Resource): try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/") @@ -486,7 +486,7 @@ class PluginFetchInstallTaskApi(Resource): try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete") @@ -501,7 +501,7 @@ class PluginDeleteInstallTaskApi(Resource): try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/delete_all") @@ -516,7 +516,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete/") @@ -531,7 +531,7 @@ class PluginDeleteInstallTaskItemApi(Resource): try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/marketplace") @@ -553,7 +553,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/github") @@ -580,7 +580,7 @@ class PluginUpgradeFromGithubApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/uninstall") @@ -598,7 +598,7 @@ class PluginUninstallApi(Resource): try: return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/permission/change") @@ -674,7 +674,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): provider_type=args.provider_type, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) @@ -705,7 +705,7 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource): credentials=args.credentials, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index b38f05795a..1273b85bc3 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -26,8 +26,8 @@ from core.mcp.mcp_client import MCPClient from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index ad78d2a623..feedf074b7 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -14,8 +14,8 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b..2f1e2f28bd 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 52690a12e1..ed3278a28b 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index 74005217ef..b38994f055 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,12 +16,14 @@ api = ExternalApi( inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") from . import mail as _mail +from .app import dsl as _app_dsl from .plugin import plugin as _plugin from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) __all__ = [ + "_app_dsl", "_mail", "_plugin", "_workspace", diff --git a/api/controllers/inner_api/app/__init__.py b/api/controllers/inner_api/app/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py new file mode 100644 index 0000000000..56730cf37a --- /dev/null +++ b/api/controllers/inner_api/app/dsl.py @@ -0,0 +1,110 @@ +"""Inner API endpoints for app DSL import/export. + +Called by the enterprise admin-api service. Import requires ``creator_email`` +to attribute the created app; workspace/membership validation is done by the +Go admin-api caller. +""" + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from controllers.common.schema import register_schema_model +from controllers.console.wraps import setup_required +from controllers.inner_api import inner_api_ns +from controllers.inner_api.wraps import enterprise_inner_api_only +from extensions.ext_database import db +from models import Account, App +from models.account import AccountStatus +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus + + +class InnerAppDSLImportPayload(BaseModel): + yaml_content: str = Field(description="YAML DSL content") + creator_email: str = Field(description="Email of the workspace member who will own the imported app") + name: str | None = Field(default=None, description="Override app name from DSL") + description: str | None = Field(default=None, description="Override app description from DSL") + + +register_schema_model(inner_api_ns, InnerAppDSLImportPayload) + + +@inner_api_ns.route("/enterprise/workspaces//dsl/import") +class EnterpriseAppDSLImport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc("enterprise_app_dsl_import") + @inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__]) + @inner_api_ns.doc( + responses={ + 200: "Import completed", + 202: "Import pending (DSL version mismatch requires confirmation)", + 400: "Import failed (business error)", + 404: "Creator account not found or inactive", + } + ) + def post(self, workspace_id: str): + """Import a DSL into a workspace on behalf of a specified creator.""" + args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {}) + + account = _get_active_account(args.creator_email) + if account is None: + return {"message": f"account '{args.creator_email}' not found or inactive"}, 404 + + account.set_tenant_id(workspace_id) + + with Session(db.engine) as session: + dsl_service = AppDslService(session) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=args.yaml_content, + name=args.name, + description=args.description, + ) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@inner_api_ns.route("/enterprise/apps//dsl") +class EnterpriseAppDSLExport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc( + "enterprise_app_dsl_export", + responses={ + 200: "Export successful", + 404: "App not found", + }, + ) + def get(self, app_id: str): + """Export an app's DSL as YAML.""" + include_secret = request.args.get("include_secret", "false").lower() == "true" + + app_model = db.session.query(App).filter_by(id=app_id).first() + if not app_model: + return {"message": "app not found"}, 404 + + data = AppDslService.export_dsl( + app_model=app_model, + include_secret=include_secret, + ) + + return {"data": data}, 200 + + +def _get_active_account(email: str) -> Account | None: + """Look up an active account by email. + + Workspace membership is already validated by the Go admin-api caller. + """ + account = db.session.query(Account).filter_by(email=email).first() + if account is None or account.status != AccountStatus.ACTIVE: + return None + return account diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e6..72cab3de73 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -28,8 +28,8 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from core.tools.signature import get_signed_file_url_for_plugin +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 9ddaaa315b..869fb73cf5 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -9,8 +9,8 @@ from controllers.common.schema import register_schema_model from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 38d292d0b9..86d88ddafb 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -21,7 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 98f09c44a1..31f2797d66 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -28,7 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index f853a124ef..5e7847d784 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,6 +4,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from controllers.common.file_response import enforce_download_for_html from controllers.common.schema import register_schema_model @@ -102,27 +103,27 @@ class FilePreviewApi(Resource): raise FileAccessDeniedError("Invalid file or app identifier") # First, find the MessageFile that references this upload file - message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1)) if not message_file: raise FileNotFoundError("File not found in message context") # Get the message and verify it belongs to the requesting app - message = ( - db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1) ) if not message: raise FileAccessDeniedError("File access denied: not owned by requesting app") # Get the actual upload file record - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = db.session.get(UploadFile, file_id) if not upload_file: raise FileNotFoundError("Upload file record not found") # Additional security: verify tenant isolation - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app and upload_file.tenant_id != app.tenant_id: raise FileAccessDeniedError("File access denied: tenant mismatch") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 8b47a887bb..bc06e8f386 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import Forbidden from controllers.common.fields import Site as SiteResponse @@ -28,7 +29,7 @@ class AppSiteApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 35dd22c801..94afd47f7f 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -27,12 +27,12 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab..dcf788f7a8 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,10 +14,11 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -139,10 +140,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore else: - item["embedding_available"] = False + item["embedding_available"] = False # type: ignore else: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore response = { "data": data, "has_more": len(datasets) == query.limit, @@ -253,10 +259,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124ae..2c094aa3e6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..28fa915117 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,6 +3,7 @@ from typing import Any from flask import request from flask_restx import marshal from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -17,9 +18,10 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -91,7 +93,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -103,9 +107,9 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -149,7 +153,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -157,9 +163,9 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -219,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -253,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -262,10 +272,10 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -300,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -343,7 +355,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -358,9 +372,9 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,7 +415,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -467,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -526,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 35aed40a59..5ac65fc4e6 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -3,7 +3,7 @@ from flask_restx import Resource from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7aa5b2f092..1d52b8a737 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -9,6 +9,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan @@ -62,7 +63,7 @@ def validate_app_token( def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).where(App.id == api_token.app_id).first() + app_model = db.session.get(App, api_token.app_id) if not app_model: raise Forbidden("The app no longer exists.") @@ -72,7 +73,7 @@ def validate_app_token( if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() + tenant = db.session.get(Tenant, app_model.tenant_id) if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -106,8 +107,8 @@ def validate_app_token( else: # For service API without end-user context, ensure an Account is logged in # so services relying on current_account_with_tenant() work correctly. - tenant_owner_info = ( - db.session.query(Tenant, Account) + tenant_owner_info = db.session.execute( + select(Tenant, Account) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(Account, TenantAccountJoin.account_id == Account.id) .where( @@ -115,8 +116,7 @@ def validate_app_token( TenantAccountJoin.role == "owner", Tenant.status == TenantStatus.NORMAL, ) - .one_or_none() - ) + ).one_or_none() if tenant_owner_info: tenant_model, account = tenant_owner_info @@ -277,29 +277,28 @@ def validate_dataset_token( # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id, ) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found.") if not dataset.enable_api: raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == api_token.tenant_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role.in_(["owner"])) .where(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ).one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).where(Account.id == ta.account_id).first() + account = db.session.get(Account, ta.account_id) # Login admin if account: account.current_tenant = tenant @@ -360,7 +359,9 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2b8f752668..8081dee0bd 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8634c1f43c..0528184d79 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index aa56292614..4274b8c9ab 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -20,9 +20,9 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.enums import FeedbackRating diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 6a93ef6748..fe31e9d4ac 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -11,9 +11,9 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 508d1a756a..ccef6e5b7f 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -22,9 +22,9 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bdc8df813..a846cf4b0f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity, ) +from core.app.file_access import DatabaseFileAccessController from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -26,8 +27,10 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMUsage, PromptMessage, @@ -37,15 +40,14 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from extensions.ext_database import db -from factories import file_factory +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class BaseAgentRunner(AppRunner): @@ -138,6 +140,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description @@ -524,7 +527,10 @@ class BaseAgentRunner(AppRunner): image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config + message_files=files, + tenant_id=self.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) if not file_objs: return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bd..0a0fdfdd29 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -15,8 +15,8 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, @@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 89451a0498..b3fc8d42e6 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,16 +1,16 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder class CotChatAgentRunner(CotAgentRunner): diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3023b9bc4d..51a30998ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,13 +1,13 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class CotCompletionAgentRunner(CotAgentRunner): diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b21..d38d24d1e7 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -11,8 +11,8 @@ from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessag from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, @@ -25,7 +25,7 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from models.model import Message logger = logging.getLogger(__name__) @@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 82676f1ebd..c3e56fe011 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -4,7 +4,7 @@ from collections.abc import Generator from typing import Union from core.agent.entities import AgentScratchpadUnit -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 558b6e69a0..dbd7527fc6 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,10 +4,10 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: @@ -21,7 +21,7 @@ class ModelConfigConverter: """ model_config = app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e33..f279f769aa 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,8 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -54,9 +53,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -71,8 +73,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index b7073898d6..7715a5330a 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -7,7 +7,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 8de1224a89..6d63ae04d3 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -3,7 +3,7 @@ from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 95ea70bc40..c67412cc29 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -5,10 +5,10 @@ from typing import Any, Literal from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.file import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 0c4266fbeb..9092c1a17d 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any from constants import DEFAULT_FILE_NUMBER_LIMITS -from dify_graph.file import FileUploadConfig +from graphon.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index d2a9a73380..13ace32fd6 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,7 @@ import re from core.app.app_config.entities import RagPipelineVariableEntity -from dify_graph.variables.input_entities import VariableEntity +from graphon.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff..d69a80e4a9 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -5,7 +5,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -22,8 +22,14 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,20 +40,15 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -150,85 +151,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id ) - else: - file_objs = [] - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True # type: ignore - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + workflow_run_id=str(workflow_run_id), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True # type: ignore + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) - - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - return self._generate( - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - stream=streaming, - pause_state_config=pause_state_config, - ) + return self._generate( + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + stream=streaming, + pause_state_config=pause_state_config, + ) def resume( self, @@ -460,94 +463,91 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + is_first_conversation = conversation is None - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) - # get conversation dialogue count - # NOTE: dialogue_count should not start from 0, - # because during the first conversation, dialogue_count should be 1. - self._dialogue_count = get_thread_messages_length(conversation.id) + 1 + # get conversation dialogue count + # NOTE: dialogue_count should not start from 0, + # because during the first conversation, dialogue_count should be 1. + self._dialogue_count = get_thread_messages_length(conversation.id) + 1 - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": context, - "variable_loader": variable_loader, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": context, + "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # db.session.refresh(workflow) - # db.session.refresh(message) - # db.session.refresh(user) - db.session.close() + worker_thread.start() - # return response or stream generator - response = self._handle_advanced_chat_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), - ) + # Capture the scalar fields needed by the response pipeline before + # releasing the request-scoped SQLAlchemy session. + workflow_snapshot = WorkflowSnapshot.from_workflow(workflow) + conversation_snapshot = ConversationSnapshot.from_conversation(conversation) + message_snapshot = MessageSnapshot.from_message(message) + db.session.close() - return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow_snapshot, + queue_manager=queue_manager, + conversation=conversation_snapshot, + message=message_snapshot, + user=user, + stream=stream, + draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -648,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): self, *, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, @@ -688,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e - - -_T = TypeVar("_T", bound=Base) - - -def _refresh_model(session, model: _T) -> _T: - with Session(bind=db.engine, expire_on_commit=False) as session: - detach_model = session.get(type(model), model.id) - assert detach_model is not None - return detach_model diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af..d21fce144e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,19 +25,24 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -132,6 +137,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs @@ -150,7 +156,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +175,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +197,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d33..3577ae139b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,8 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime from threading import Thread from typing import Any, Union @@ -14,6 +16,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,24 +68,72 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent +from models.model import AppMode from models.workflow import Workflow logger = logging.getLogger(__name__) +@dataclass(frozen=True, slots=True) +class WorkflowSnapshot: + id: str + tenant_id: str + features_dict: Mapping[str, Any] + + @classmethod + def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot": + return cls( + id=workflow.id, + tenant_id=workflow.tenant_id, + features_dict=dict(workflow.features_dict), + ) + + +@dataclass(frozen=True, slots=True) +class ConversationSnapshot: + id: str + mode: AppMode + + @classmethod + def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot": + return cls( + id=conversation.id, + mode=conversation.mode, + ) + + +@dataclass(frozen=True, slots=True) +class MessageSnapshot: + id: str + query: str + created_at: datetime + status: MessageStatus + answer: str + + @classmethod + def from_message(cls, message: Message) -> "MessageSnapshot": + return cls( + id=message.id, + query=message.query, + created_at=message.created_at, + status=message.status, + answer=message.answer, + ) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -91,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], stream: bool, dialogue_count: int, @@ -117,7 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -155,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() - def _seed_task_state_from_message(self, message: Message) -> None: + def _seed_task_state_from_message(self, message: MessageSnapshot) -> None: if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer @@ -741,8 +792,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -933,21 +985,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - message_files = [ - MessageFile( - message_id=message.id, - type=file["type"], - transfer_method=file["transfer_method"], - url=file["remote_url"], - belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], - created_by_role=CreatorUserRole.ACCOUNT - if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatorUserRole.END_USER, - created_by=message.from_account_id or message.from_end_user_id or "", + message_files: list[MessageFile] = [] + for file in self._recorded_files: + reference = file.get("reference") or file.get("related_id") + message_files.append( + MessageFile( + message_id=message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to=MessageFileBelongsTo.ASSISTANT, + upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None), + created_by_role=CreatorUserRole.ACCOUNT + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER, + created_by=message.from_account_id or message.from_end_user_id or "", + ) ) - for file in self._recorded_files - ] session.add_all(message_files) def _seed_graph_runtime_state_from_queue_manager(self) -> None: @@ -1003,13 +1057,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 76a067d7b6..1a44cc235e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -21,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService @@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args.get("files") or [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args.get("files") or [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) + # get tracing instance + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) - # init application generate entity - application_generate_entity = AgentChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - call_depth=0, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + call_depth=0, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) - # new thread with request context and contextvars - context = contextvars.copy_context() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": context, - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a81da2e91c..09ddce327e 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,10 +15,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a92e3dd2ea..5c9ba4567a 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -6,7 +6,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea..8e8ccf2b90 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,27 +1,89 @@ from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from dify_graph.variables.input_entities import VariableEntityType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope +from extensions.ext_database import db from factories import file_factory +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from dify_graph.variables.input_entities import VariableEntity + from graphon.variables.input_entities import VariableEntity + + +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) class BaseAppGenerator: + _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController() + + @staticmethod + def _bind_file_access_scope( + *, + tenant_id: str, + user: Account | EndUser, + invoke_from: InvokeFrom, + ) -> AbstractContextManager[None]: + """Bind request-scoped file ownership markers for downstream file lookups.""" + + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return nullcontext() + + user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER + return bind_file_access_scope( + FileAccessScope( + tenant_id=tenant_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + ) + def _prepare_user_inputs( self, *, @@ -50,6 +112,7 @@ class BaseAppGenerator: allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE @@ -64,6 +127,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, list) @@ -226,32 +290,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 5addd41815..d1771452c5 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -20,8 +20,8 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client +from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -61,27 +61,30 @@ class AppQueueManager(ABC): listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time: int | float = 0 - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break + try: + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE - ) + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE + ) - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + finally: + self._graph_runtime_state = None # Release reference once consumers finish or close the generator. def stop_listen(self): """ @@ -90,7 +93,6 @@ class AppQueueManager(ABC): """ self._clear_task_belong_cache() self._q.put(None) - self._graph_runtime_state = None # Release reference to allow GC to reclaim memory def _clear_task_belong_cache(self) -> None: """ diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 11fcbb7561..4a4c8b535d 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -29,22 +29,22 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from extensions.ext_database import db +from graphon.file.enums import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file.models import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 91cf54c774..db3a98c7ac 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -20,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService @@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = ChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - stream=streaming, - ) + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + stream=streaming, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc86..077c5239f3 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -15,9 +15,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e436163..2a90fbdad0 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dify_graph.runtime import GraphRuntimeState +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -30,10 +31,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf3..e4aa2ff650 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -50,22 +51,23 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import ( +from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import FILE_MODEL_IDENTITY, File -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm @@ -111,11 +113,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. @@ -133,7 +135,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +320,23 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, str] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) responses: list[StreamResponse] = [] @@ -344,8 +356,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id), resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 002b914ef1..c418fe9759 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -20,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras={}, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict - # parse files - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=message.message_files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) - else: - file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - inputs=message.inputs, - query=message.query, - files=list(file_objs), - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras={}, - ) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a4519879..6bb1ecdcb1 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -13,9 +13,9 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 65% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd..24018012c5 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - -from dify_graph.enums import NodeType +from graphon.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 44d10d79b8..fe61224ada 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,6 +28,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import resolve_file_record_id from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic @@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108..48457b5326 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -18,6 +18,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,13 +35,14 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb..44d2450f74 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -12,19 +12,19 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db +from graphon.entities.graph_init_params import GraphInitParams +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow @@ -106,13 +106,14 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow=workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +143,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2..8ad6893a15 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -17,6 +17,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,15 +31,13 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom @@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator): graph_engine_layers: Sequence[GraphEngineLayer] = (), pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: - files: Sequence[Mapping[str, Any]] = args.get("files") or [] + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files: Sequence[Mapping[str, Any]] = args.get("files") or [] - # parse files - # TODO(QuantumGhost): Move file parsing logic to the API controller layer - # for better separation of concerns. - # - # For implementation reference, see the `_parse_file` function and - # `DraftWorkflowNodeRunApi` class which handle this properly. - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - system_files = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, - strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, - ) - - # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow, - ) - - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, - user_id=user.id if isinstance(user, Account) else user.session_id, - ) - - inputs: Mapping[str, Any] = args["inputs"] - - extras = { - **extract_external_trace_id_from_args(args), - } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) - # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args - # trigger shouldn't prepare user inputs - if self._should_prepare_user_inputs(args): - inputs = self._prepare_user_inputs( - user_inputs=inputs, - variables=app_config.variables, + # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + system_files = file_factory.build_from_mappings( + mappings=files, tenant_id=app_model.tenant_id, + config=file_extra_config, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + access_controller=self._file_access_controller, ) - # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - inputs=inputs, - files=list(system_files), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - call_depth=call_depth, - trace_manager=trace_manager, - workflow_execution_id=workflow_run_id, - extras=extras, - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if triggered_from is not None: - # Use explicitly provided triggered_from (for async triggers) - workflow_triggered_from = triggered_from - elif invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - root_node_id=root_node_id, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, - ) + inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } + workflow_run_id = str(workflow_run_id or uuid.uuid4()) + # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args + # trigger shouldn't prepare user inputs + if self._should_prepare_user_inputs(args): + inputs = self._prepare_user_inputs( + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ) + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + inputs=inputs, + files=list(system_files), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + trace_manager=trace_manager, + workflow_execution_id=workflow_run_id, + extras=extras, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + root_node_id=root_node_id, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, + ) def resume( self, @@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager - queue_manager = WorkflowAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode, - ) - - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - # release database connection, because the following new thread operations may take a long time - db.session.close() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": context, - "variable_loader": variable_loader, - "root_node_id": root_node_id, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # release database connection, because the following new thread operations may take a long time + db.session.close() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": context, + "variable_loader": variable_loader, + "root_node_id": root_node_id, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - draft_var_saver_factory=draft_var_saver_factory, - stream=streaming, - ) + draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + draft_var_saver_factory=draft_var_saver_factory, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b95..c02c0b16e9 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -8,17 +8,18 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow @@ -91,12 +92,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +106,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +126,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..e0c5b44ee4 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,12 +56,11 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9af..d7d3bd27de 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -34,13 +34,22 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -66,10 +75,9 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from graphon.graph_events.graph import GraphRunAbortedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -156,6 +164,8 @@ class WorkflowBasedAppRunner: workflow: Workflow, single_iteration_run: Any | None = None, single_loop_run: Any | None = None, + *, + user_id: str, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -173,14 +183,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -191,6 +202,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id=user_id, ) elif single_loop_run: graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( @@ -200,6 +212,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id=user_id, ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -216,6 +229,8 @@ class WorkflowBasedAppRunner: graph_runtime_state: GraphRuntimeState, node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_label: str = "node", # 'iteration' or 'loop' for error messages + *, + user_id: str = "", ) -> tuple[Graph, VariablePool]: """ Get graph and variable pool for single node execution (iteration or loop). @@ -272,6 +287,8 @@ class WorkflowBasedAppRunner: graph_config["edges"] = edge_configs + typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] + # Create required parameters for Graph.init graph_init_params = GraphInitParams( workflow_id=workflow.id, @@ -279,7 +296,7 @@ class WorkflowBasedAppRunner: run_context=build_dify_run_context( tenant_id=workflow.tenant_id, app_id=self._app_id, - user_id="", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, ), @@ -291,26 +308,15 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) - # init graph - graph = Graph.init( - graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True - ) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id target_node_config = None - for node in node_configs: - if node.get("id") == node_id: + for node in typed_node_configs: + if node["id"] == node_id: target_node_config = node break if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) @@ -319,12 +325,31 @@ class WorkflowBasedAppRunner: # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool + preload_node_creation_variables( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + selectors=[ + selector + for node_config in typed_node_configs + for selector in get_node_creation_preload_selectors( + node_type=node_config["data"].type, + node_data=node_config["data"], + ) + ], + ) + try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, @@ -340,6 +365,14 @@ class WorkflowBasedAppRunner: tenant_id=workflow.tenant_id, ) + # init graph after constructor-time context has been loaded + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) + + if not graph: + raise ValueError("graph not found in workflow") + return graph, variable_pool @staticmethod @@ -408,7 +441,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeRetryEvent( @@ -448,7 +485,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeSucceededEvent( @@ -466,6 +507,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunFailedEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeFailedEvent( node_execution_id=event.id, @@ -475,7 +521,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, @@ -483,6 +529,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunExceptionEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeExceptionEvent( node_execution_id=event.id, @@ -492,7 +543,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..d8d851c505 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,14 +7,16 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.file import File, FileUploadConfig -from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d2a36f2a0d..63857bfff2 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -7,10 +7,10 @@ from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.entities.pause_reason import PauseReason +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 46a8ab52f2..719027bd23 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,10 +6,10 @@ from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 87d4772815..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,6 +4,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset from models.enums import CollectionBindingType, ConversationFromSource @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index 5ed1fadc41..d59f5125e3 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -2,7 +2,7 @@ import logging from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from dify_graph.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py new file mode 100644 index 0000000000..a75ab9781b --- /dev/null +++ b/api/core/app/file_access/__init__.py @@ -0,0 +1,11 @@ +from .controller import DatabaseFileAccessController +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope + +__all__ = [ + "DatabaseFileAccessController", + "FileAccessControllerProtocol", + "FileAccessScope", + "bind_file_access_scope", + "get_current_file_access_scope", +] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py new file mode 100644 index 0000000000..300c187083 --- /dev/null +++ b/api/core/app/file_access/controller.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable + +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, get_current_file_access_scope + + +class DatabaseFileAccessController(FileAccessControllerProtocol): + """Workflow-layer authorization helper for database-backed file lookups. + + Tenant scoping remains mandatory. When the current execution belongs to an + end user, the lookup is additionally constrained to that end user's file + ownership markers. + """ + + _scope_getter: Callable[[], FileAccessScope | None] + + def __init__( + self, + *, + scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope, + ) -> None: + self._scope_getter = scope_getter + + def current_scope(self) -> FileAccessScope | None: + return self._scope_getter() + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where( + UploadFile.created_by_role == CreatorUserRole.END_USER, + UploadFile.created_by == resolved_scope.user_id, + ) + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(UploadFile, file_id) + + stmt = self.apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(ToolFile, file_id) + + stmt = self.apply_tool_file_filters( + select(ToolFile).where(ToolFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) diff --git a/api/core/app/file_access/protocols.py b/api/core/app/file_access/protocols.py new file mode 100644 index 0000000000..8bb3eb9924 --- /dev/null +++ b/api/core/app/file_access/protocols.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Protocol + +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile + +from .scope import FileAccessScope + + +class FileAccessControllerProtocol(Protocol): + """Contract for applying access rules to file lookups. + + Implementations translate an optional execution scope into query constraints + and authorized record retrieval. The contract is intentionally limited to + ownership and tenancy rules for workflow-layer file access. + """ + + def current_scope(self) -> FileAccessScope | None: + """Return the scope active for the current execution, if one exists. + + Callers use this to decide whether embedded file metadata may be trusted + or whether a fresh authorized lookup is required. + """ + ... + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + """Return an upload-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + """Return a tool-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + """Load one authorized upload-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + """Load one authorized tool-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py new file mode 100644 index 0000000000..80d504ef1c --- /dev/null +++ b/api/core/app/file_access/scope.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + +_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( + "current_file_access_scope", + default=None, +) + + +@dataclass(frozen=True, slots=True) +class FileAccessScope: + """Request-scoped ownership context used by workflow-layer file lookups.""" + + tenant_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + @property + def requires_user_ownership(self) -> bool: + return self.user_from == UserFrom.END_USER + + +def get_current_file_access_scope() -> FileAccessScope | None: + return _current_file_access_scope.get() + + +@contextmanager +def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: + token = _current_file_access_scope.set(scope) + try: + yield + finally: + _current_file_access_scope.reset(token) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904..eeb9abbbfa 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,19 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. +""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0b..98e2257b1f 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,9 +6,10 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory @@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 2adaf14a35..172306f271 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,21 +1,28 @@ -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): """ """ + def __init__(self) -> None: + super().__init__() + self._paused = False + def on_graph_start(self): - pass + self._paused = False def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. """ if isinstance(event, GraphRunPausedEvent): - pass + self._paused = True def on_graph_end(self, error: Exception | None): """ """ - pass + self._paused = False + + def is_paused(self) -> bool: + return self._paused diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index d7ca45f209..fef12df504 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -4,9 +4,9 @@ from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore -from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e1..781a0aa3d3 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,9 +5,10 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity @@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..c49c4eb0ac 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,34 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.llm.entities import ModelConfig -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +53,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,18 +78,42 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. + """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( *, node_data_model: ModelConfig, credentials_provider: CredentialsProvider, - model_factory: ModelFactory, + model_factory: DifyModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -80,22 +127,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +146,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 7aa3bf15ab..65a3f39d64 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -6,8 +6,8 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 0d5e0acec6..9e688589db 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -17,7 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b530fe1ce4..cf9cb6d051 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -51,15 +51,15 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.enums import FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from events.message_event import message_was_created +from extensions.ext_database import db +from graphon.file.enums import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.message_event import message_was_created -from extensions.ext_database import db +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index fc8b6c6b5a..45f622c469 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,8 +1,8 @@ from typing import TypedDict from core.tools.signature import sign_tool_file -from dify_graph.file import helpers as file_helpers -from dify_graph.file.enums import FileTransferMethod +from graphon.file import helpers as file_helpers +from graphon.file.enums import FileTransferMethod from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d27111..aa5291bad5 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,33 +1,42 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal from configs import dify_config +from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from dify_graph.file.runtime import set_workflow_file_runtime +from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +from graphon.file.enums import FileTransferMethod +from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime + +if TYPE_CHECKING: + from graphon.file.models import File class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``dify_graph.file``.""" + """Production runtime wiring for ``graphon.file``. - @property - def files_url(self) -> str: - return dify_config.FILES_URL + Opaque file references are resolved back to canonical database records before + URLs are signed or storage keys are used. When a request-scoped file access + scope is present, those lookups additionally enforce tenant and end-user + ownership filters. + """ - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL + _file_access_controller: FileAccessControllerProtocol - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT + def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None: + self._file_access_controller = file_access_controller @property def multimodal_send_format(self) -> str: @@ -39,9 +48,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + self._assert_upload_file_access(upload_file_id=parsed_reference.record_id) + return sign_tool_file( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.TOOL_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._assert_upload_file_access(upload_file_id=upload_file_id) + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + + def _assert_upload_file_access(self, *, upload_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id) + if upload_file is None: + raise ValueError(f"Upload file {upload_file_id} not found") + + def _assert_tool_file_access(self, *, tool_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id) + if tool_file is None: + raise ValueError(f"Tool file {tool_file_id} not found") + def bind_dify_workflow_file_runtime() -> None: - set_workflow_file_runtime(DifyWorkflowFileRuntime()) + set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index faf1516c40..5666bf1191 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -9,20 +9,21 @@ from typing import TYPE_CHECKING, cast, final from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase +from graphon.graph_events.node import NodeRunSucceededEvent +from graphon.nodes.base.node import Node if TYPE_CHECKING: - from dify_graph.nodes.llm.node import LLMNode - from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from graphon.nodes.llm.node import LLMNode + from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, @@ -114,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -127,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 4b20477a7f..837bf7ff81 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -16,10 +16,6 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_ from typing_extensions import override from configs import dify_config -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, @@ -28,6 +24,10 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 99b64b3ab5..e540733de2 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,17 +17,19 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution -from dify_graph.enums import ( - SystemVariableKey, +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities import WorkflowExecution, WorkflowNodeExecution +from graphon.enums import ( WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +44,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now @@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._handle_graph_run_paused(event) return - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - if isinstance(event, NodeRunRetryEvent): self._handle_node_retry(event) return + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + if isinstance(event, NodeRunSucceededEvent): self._handle_node_succeeded(event) return @@ -372,10 +372,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = error if update_outputs: + projected_outputs = project_node_outputs_for_workflow_run( + node_type=domain_execution.node_type, + inputs=node_result.inputs, + outputs=node_result.outputs, + ) domain_execution.update_from_mapping( inputs=node_result.inputs, process_data=node_result.process_data, - outputs=node_result.outputs, + outputs=projected_outputs, metadata=node_result.metadata, ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666..9e3c187210 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -15,8 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: @@ -25,12 +25,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher: self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 24243add17..fe40d8f0e5 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -214,6 +214,6 @@ class DatasourceFileManager: # init tool_file_parser -# from dify_graph.file.datasource_file_parser import datasource_file_manager +# from graphon.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 4fa941ae16..8a9875e4d7 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,6 +6,7 @@ from typing import Any, cast from sqlalchemy import select import contexts +from core.app.file_access import DatabaseFileAccessController from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.datasource_entities import ( @@ -24,18 +25,20 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.file import File, get_file_type_by_mime_type +from graphon.file.enums import FileTransferMethod, FileType +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class DatasourceManager: @@ -279,11 +282,15 @@ class DatasourceManager: if datasource_file is not None: mapping = { "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(mime_type), + "type": get_file_type_by_mime_type(mime_type), "transfer_method": FileTransferMethod.TOOL_FILE, "url": url, } - file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + file_out = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) elif mtype == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) @@ -351,11 +358,10 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id)), size=upload_file.size, storage_key=upload_file.key, url=upload_file.source_url, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 4c9ff64479..84dd653772 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 2881888e27..089b8b8e59 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,7 +4,8 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -103,8 +104,14 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2ef..f49cbf9ffe 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. - """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 1343bd8e82..9d970d5db1 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index d214652e9c..bfa4f56915 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -15,7 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3427fc54b1..e99a131500 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -3,9 +3,9 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel +from graphon.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2..d90afd3f7b 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -19,15 +21,17 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a830f227a9..dffc7f2fc1 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -12,7 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4251cfd30b..951e065b2c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -13,7 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from dify_graph.nodes.code.entities import CodeLanguage +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index c569e066f4..b96a9ce380 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a4093..dc37a36943 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,11 +4,11 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 600a444357..eb762c3508 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..46bf1d6937 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -31,10 +31,10 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -271,7 +274,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,22 +292,22 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_default_model_instance( + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -573,8 +576,8 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_model_instance( + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -587,7 +590,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +600,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +631,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +657,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,16 +767,16 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d9..3712374305 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -27,13 +27,13 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow @@ -62,7 +62,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +120,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +172,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +219,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -306,7 +306,7 @@ class LLMGenerator: remove_template_variables=False, ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -337,7 +337,7 @@ class LLMGenerator: def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -362,7 +362,7 @@ class LLMGenerator: @classmethod def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -536,7 +536,7 @@ class LLMGenerator: injected_instruction = injected_instruction.replace(CURRENT, current or "null") if ERROR_MESSAGE in injected_instruction: injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 77ea1713ea..81672ee7aa 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -10,22 +10,22 @@ from pydantic import TypeAdapter, ValidationError from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT from core.model_manager import ModelInstance -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule class ResponseFormat(StrEnum): @@ -55,7 +55,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload @@ -70,7 +69,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload @@ -85,7 +83,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index de68eb268b..92d23c6dc9 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -7,7 +7,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index db9cb726d7..7b5a7635f1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -8,7 +8,7 @@ from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1156a98af1..658206128d 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -4,10 +4,13 @@ from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController from core.model_manager import ModelInstance from core.prompt.utils.extract_thread_messages import extract_thread_messages -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -15,14 +18,14 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from extensions.ext_database import db -from factories import file_factory +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +_file_access_controller = DatabaseFileAccessController() + class TokenBufferMemory: def __init__( @@ -85,7 +88,10 @@ class TokenBufferMemory: # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( - message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config + message_file=message_file, + tenant_id=app_record.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) for message_file in message_files ] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..f5ff375f65 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,21 +7,22 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType @@ -30,7 +31,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -49,6 +50,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ @@ -110,7 +118,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +129,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +140,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... @@ -145,7 +150,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +160,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +176,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +204,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +222,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +229,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +247,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +276,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -288,7 +284,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -303,7 +298,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +307,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +315,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +329,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +348,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model :param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +367,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +387,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +460,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +489,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf4..35d4469bc1 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,6 @@ from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): @@ -50,7 +50,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c..76e81242f4 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -57,9 +57,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 45319f24c1..43b204b78c 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -14,9 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 7cb54b2c88..724127c31c 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -39,6 +39,7 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +182,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance): arize_phoenix_config: ArizeConfig | PhoenixConfig, ): super().__init__(arize_phoenix_config) - import logging - - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config self.tracer, self.processor = setup_tracer(arize_phoenix_config) self.project = arize_phoenix_config.project @@ -275,8 +272,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: @@ -304,7 +301,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_name": node_execution.title, "status": node_execution.status, "status_message": node_execution.error or "", - "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", + "level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT", } ) @@ -365,7 +362,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) finally: - if node_execution.status == "failed": + if node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..45b2f635ba 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. + """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] + invoked_by: str | None = None query: str metadata: dict[str, Any] @@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,11 +246,31 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. + + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. + """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" @@ -140,4 +278,6 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" + NODE_EXECUTION_TRACE = "node_execution" DATASOURCE_TRACE = "datasource" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1f..4a634e2e57 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,8 +28,8 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2..9f7d73b4ca 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,8 +28,8 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index ab4a7650ec..8ec69e3542 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -23,8 +23,8 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc2381..a3ead548bb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -23,8 +23,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9ac753240b..0a2a0642f1 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,34 +15,179 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from dify_graph.entities import WorkflowExecution + from graphon.entities import WorkflowExecution logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined] + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. + """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + exc_info=True, + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + exc_info=True, + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, key: str) -> dict[str, Any]: - match key: + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {key}") + raise KeyError(f"Unsupported tracing provider: {provider}") provider_config_map = OpsTraceProviderConfigMap() @@ -314,6 +459,10 @@ class OpsTraceManager: if app_id is None: return None + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: @@ -466,8 +615,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: @@ -478,6 +625,77 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _calculate_workflow_token_split( + cls, session: "Session", workflow_run_id: str, tenant_id: str + ) -> tuple[int, int]: + """Sum prompt/completion tokens across all node executions for a workflow run. + + Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens`` + and ``usage.completion_tokens``) rather than ``execution_metadata``, which only + carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading + large JSON blobs unnecessarily. + """ + import json + + from models.workflow import WorkflowNodeExecutionModel + + rows = ( + session.execute( + select(WorkflowNodeExecutionModel.outputs).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ) + .scalars() + .all() + ) + + total_prompt = 0 + total_completion = 0 + + for raw in rows: + if not raw: + continue + try: + outputs = json.loads(raw) if isinstance(raw, str) else raw + except (ValueError, TypeError): + continue + if not isinstance(outputs, dict): + continue + usage = outputs.get("usage") + if not isinstance(usage, dict): + continue + prompt = usage.get("prompt_tokens") + if isinstance(prompt, (int, float)): + total_prompt += int(prompt) + completion = usage.get("completion_tokens") + if isinstance(completion, (int, float)): + total_completion += int(completion) + + return (total_prompt, total_completion) + + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. + """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + def __init__( self, trace_type: Any, @@ -491,6 +709,7 @@ class TraceTask: self.trace_type = trace_type self.message_id = message_id self.workflow_run_id = workflow_execution.id_ if workflow_execution else None + self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -498,6 +717,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -509,9 +730,12 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + workflow_run_id=self.workflow_run_id, + conversation_id=self.conversation_id, + user_id=self.user_id, + total_tokens_override=self.workflow_total_tokens, ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -527,6 +751,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -541,6 +768,7 @@ class TraceTask: workflow_run_id: str | None, conversation_id: str | None, user_id: str | None, + total_tokens_override: int | None = None, ): if not workflow_run_id: return {} @@ -560,7 +788,7 @@ class TraceTask: workflow_run_version = workflow_run.version error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -581,8 +809,18 @@ class TraceTask: Message.workflow_run_id == workflow_run_id, ) message_id = session.scalar(message_data_stmt) + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + session, workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -595,8 +833,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -611,6 +855,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -618,10 +864,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -644,6 +891,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -655,7 +915,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -672,7 +939,9 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -697,6 +966,8 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -738,6 +1009,8 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -777,6 +1050,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -787,13 +1106,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -836,6 +1165,10 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() @@ -890,6 +1223,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -904,6 +1239,182 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": credential_id, + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -937,13 +1448,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -979,20 +1494,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 0a6013e244..4f06458157 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -41,7 +41,7 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa..1b1b1025bc 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -24,11 +24,11 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from dify_graph.nodes import BuiltinNodeTypes -from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index ef1a3be45b..ed6a7dabbb 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): if field_name == "inputs": data = { "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v ] if isinstance(v, list) else v, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c..a55505822a 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -31,8 +31,8 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bac..85625fc87d 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -18,22 +18,39 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -325,6 +337,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -337,6 +350,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -394,6 +408,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index d6aef93fc4..248f8ef3e6 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,17 +1,17 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) from services.workflow_service import WorkflowService diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67..0585494269 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 81e1e12c5f..1bd239a831 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -4,7 +4,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 7a3780f7de..6aefc41400 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -13,7 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 416e0f6b4d..864e4b8dd7 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -16,8 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index c15e9b0385..704cacae2a 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.provider_entities import BasicProviderConfig from core.plugin.utils.http_parser import deserialize_response -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -17,17 +17,17 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 737d204105..44047911da 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -13,6 +13,7 @@ from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -27,14 +28,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( @@ -235,7 +236,10 @@ class BasePluginClient: response.raise_for_status() except httpx.HTTPStatusError as e: logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) - raise e + if e.response.status_code < 500: + raise PluginDaemonClientSideError(description=str(e)) + else: + raise PluginDaemonInternalServerError(description=str(e)) except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" logger.exception("Failed to request plugin daemon, url: %s", path) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb..c91fa71374 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -13,15 +13,22 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file mode 100644 index 0000000000..e3fba4ef3a --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] | None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_moderation( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + text=text, + ) + + def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = self._get_provider_short_name_alias(provider) + return declaration + + def _get_provider_schema(self, provider: str) -> ProviderEntity: + providers = self.fetch_model_providers() + provider_entity = next((item for item in providers if item.provider == provider), None) + if provider_entity is None: + provider_entity = next((item for item in providers if provider == item.provider_name), None) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + return provider_entity + + def _get_schema_cache_key( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> str: + # The plugin daemon distinguishes ``None`` from an explicit empty-string + # caller id, so the cache must only collapse ``None`` into tenant scope. + cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + if not sorted_credentials: + return cache_key + hashed_credentials = ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + return f"{cache_key}:{hashed_credentials}" + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 0000000000..35abd2ae8c --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: + """Create a tenant-bound model provider factory for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory + + +def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: + """Create a tenant-bound provider manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager + + +def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: + """Create a tenant-bound model manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 0bbb62af93..ec4858ae2e 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - data={"plugin_unique_identifier": plugin_unique_identifier}, - headers={"Content-Type": "application/json"}, + params={"plugin_unique_identifier": plugin_unique_identifier}, ) def fetch_plugin_installation_by_ids( diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 53bcd9e9c6..322f78ab4e 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,7 @@ from typing import Any from core.tools.entities.tool_entities import ToolSelector -from dify_graph.file.models import File +from graphon.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce9f7e64b2..de87a09652 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -8,9 +8,9 @@ from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.file.models import File +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -18,8 +18,8 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.runtime import VariablePool +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index d09a46bfde..8f1d51f08a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -5,12 +5,12 @@ from core.app.entities.app_invoke_entities import ( ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 667f5ef099..b98fd8c179 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,50 +1,7 @@ -from typing import Literal +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """ - Chat Message. - """ - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """ - Completion Model Prompt Template. - """ - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """ - Memory Config. - """ - - class RolePrefix(BaseModel): - """ - Role Prefix. - """ - - user: str - assistant: str - - class WindowConfig(BaseModel): - """ - Window Config. - """ - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 951736831f..6ff2f44cdc 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 10c44349ae..e091215b80 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -10,8 +10,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -22,7 +22,7 @@ from dify_graph.model_runtime.entities.message_entities import ( from models.model import AppMode if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file.models import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 85a2201395..ba76eb0c4e 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import Any, cast from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..79fd78fe80 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import contextlib import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -28,17 +30,17 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from extensions import ext_hosting_provider +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions import ext_hosting_provider -from extensions.ext_database import db -from extensions.ext_redis import redis_client +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, @@ -53,15 +55,25 @@ from models.provider import ( from models.provider_ids import ModelProviderID from services.feature_service import FeatureService +if TYPE_CHECKING: + from graphon.model_runtime.runtime import ModelRuntime + class ProviderManager: """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + ProviderManager manages tenant-scoped model provider configuration. + + The runtime adapter is injected by the composition layer so this class stays + focused on configuration assembly instead of constructing plugin runtimes. + Request-bound managers may carry caller identity in that runtime, and the + resulting ``ProviderConfiguration`` objects must reuse it for downstream + model-type and schema lookups. """ - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +139,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -255,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration @@ -321,7 +334,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( @@ -392,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.value, + model_type=model_type.to_origin_model_type(), provider_name=provider, model_name=model, ) @@ -918,11 +931,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963a..2c81653559 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -8,8 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): @@ -52,11 +52,10 @@ class DataPostProcessor: documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) @@ -106,9 +105,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9d..1e4aa24287 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -23,8 +23,8 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d834495..9f5842e449 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore from pymochow.model.schema import ( + AutoBuildRowCountIncrement, Field, FilteringIndex, HNSWParams, @@ -51,6 +52,9 @@ class BaiduConfig(BaseModel): replicas: int = 3 inverted_index_analyzer: str = "DEFAULT_ANALYZER" inverted_index_parser_mode: str = "COARSE_MODE" + auto_build_row_count_increment: int = 500 + auto_build_row_count_increment_ratio: float = 0.05 + rebuild_index_timeout_in_seconds: int = 300 @model_validator(mode="before") @classmethod @@ -107,18 +111,6 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) - # rebuild vector index after upsert finished - table.rebuild_index(self.vector_index) - timeout = 3600 # 1 hour timeout - start_time = time.time() - while True: - time.sleep(1) - index = table.describe_index(self.vector_index) - if index.state == IndexState.NORMAL: - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") - def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: @@ -232,8 +224,14 @@ class BaiduVector(BaseVector): return self._client.database(self._client_config.database) def _table_existed(self) -> bool: - tables = self._db.list_table() - return any(table.table_name == self._collection_name for table in tables) + try: + table = self._db.table(self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + return False + else: + raise + return True def _create_table(self, dimension: int): # Try to grab distributed lock and create table @@ -287,6 +285,11 @@ class BaiduVector(BaseVector): field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement( + row_count_increment=self._client_config.auto_build_row_count_increment, + row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio, + ), ) ) @@ -335,7 +338,7 @@ class BaiduVector(BaseVector): ) # Wait for table created - timeout = 300 # 5 minutes timeout + timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout start_time = time.time() while True: time.sleep(1) @@ -345,6 +348,20 @@ class BaiduVector(BaseVector): if time.time() - start_time > timeout: raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) + # rebuild vector index immediately after table created, make sure index is ready + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + self._wait_for_index_ready(table, timeout) + + def _wait_for_index_ready(self, table, timeout: int = 3600): + start_time = time.time() + while True: + time.sleep(1) + index = table.describe_index(self.vector_index) + if index.state == IndexState.NORMAL: + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") class BaiduVectorFactory(AbstractVectorFactory): @@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory): replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, + auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT, + auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO, + rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS, ), ) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f..df02c584ed 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector): ) ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) docs = [] for doc, score in docs_and_scores: - score_threshold = float(kwargs.get("score_threshold") or 0.0) if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 71b6fa0a9b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae..a77458706a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -14,10 +14,10 @@ from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..369159767e 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,9 +6,10 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding @@ -71,8 +72,8 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a055..b12a0ae2d6 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -10,10 +10,10 @@ from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding @@ -21,9 +21,8 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: str | None = None): + def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance - self._user = user def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 371f7b0865..e1ddd2dd96 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -95,15 +95,11 @@ class FirecrawlApp: if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": - total = crawl_status_response.get("total", 0) - if total == 0: + # Normalize to avoid None bypassing the zero-guard when the API returns null. + total = crawl_status_response.get("total") or 0 + if total <= 0: raise Exception("Failed to check crawl status. Error: No page found") - data = crawl_status_response.get("data", []) - url_data_list: list[FirecrawlDocumentData] = [] - for item in data: - if isinstance(item, dict) and "metadata" in item and "markdown" in item: - url_data = self._extract_common_fields(item) - url_data_list.append(url_data) + url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers) if url_data_list: file_key = "website_files/" + job_id + ".txt" try: @@ -120,6 +116,36 @@ class FirecrawlApp: self._handle_error(response, "check crawl status") raise RuntimeError("unreachable: _handle_error always raises") + def _collect_all_crawl_pages( + self, first_page: dict[str, Any], headers: dict[str, str] + ) -> list[FirecrawlDocumentData]: + """Collect all crawl result pages by following pagination links. + + Raises an exception if any paginated request fails, to avoid returning + partial data that is inconsistent with the reported total. + + The number of pages processed is capped at ``total`` (the + server-reported page count) to guard against infinite loops caused by + a misbehaving server that keeps returning a ``next`` URL. + """ + total: int = first_page.get("total") or 0 + url_data_list: list[FirecrawlDocumentData] = [] + current_page = first_page + pages_processed = 0 + while True: + for item in current_page.get("data", []): + if isinstance(item, dict) and "metadata" in item and "markdown" in item: + url_data_list.append(self._extract_common_fields(item)) + next_url: str | None = current_page.get("next") + pages_processed += 1 + if not next_url or pages_processed >= total: + break + response = self._get_request(next_url, headers) + if response.status_code != 200: + self._handle_error(response, "fetch next crawl page") + current_page = response.json() + return url_data_list + def _format_crawl_status_response( self, status: str, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac..a6d1db214b 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..9f36b7a225 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,11 +8,12 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -22,23 +23,24 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from dify_graph.file import File, FileTransferMethod, FileType, file_manager -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories.file_factory import build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from extensions.ext_database import db -from factories.file_factory import build_from_mapping +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account @@ -48,6 +50,8 @@ from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService +_file_access_controller = DatabaseFileAccessController() + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -117,7 +121,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -155,7 +159,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +257,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) @@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) @@ -555,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): file_obj = build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) file_objects.append(file_obj) except Exception as e: @@ -604,11 +609,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, ) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index dc3b771406..4ebf095904 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.file import File +from graphon.file import File class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 88acb75133..cc65262527 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -12,7 +12,6 @@ class BaseRerankRunner(ABC): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -21,7 +20,6 @@ class BaseRerankRunner(ABC): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52..6c6b077cc2 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -5,10 +5,10 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult from models.model import UploadFile @@ -22,7 +22,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -31,10 +30,11 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -43,12 +43,12 @@ class RerankModelRunner(BaseRerankRunner): ) if not is_support_vision: if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) else: return documents else: rerank_result, unique_documents = self.fetch_multimodal_rerank( - query, documents, score_threshold, top_n, user, query_type + query, documents, score_threshold, top_n, query_type ) rerank_documents = [] @@ -73,7 +73,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> tuple[RerankResult, list[Document]]: """ Fetch text rerank @@ -81,7 +80,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ docs = [] @@ -103,7 +101,7 @@ class RerankModelRunner(BaseRerankRunner): unique_documents.append(document) rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -113,7 +111,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> tuple[RerankResult, list[Document]]: """ @@ -122,7 +119,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :param query_type: query type :return: rerank result """ @@ -168,7 +164,7 @@ class RerankModelRunner(BaseRerankRunner): documents = unique_documents if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access @@ -181,7 +177,7 @@ class RerankModelRunner(BaseRerankRunner): "content_type": DocType.IMAGE, } rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d1..d0732b269a 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -11,7 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): @@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ @@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..49b91707ec 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -63,13 +64,14 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( SourceChildChunk, SourceMetadata, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ( @@ -160,7 +162,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -383,23 +385,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - # get model schema + # Reuse the caller-bound model instance for both schema resolution and + # downstream planner/invoke calls so a single request never mixes + # tenant-scope and request-bound runtimes. model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials + model=model_instance.model_name, + credentials=model_instance.credentials, ) if not model_schema: return None, [] + model_config.provider_model_bundle = model_instance.provider_model_bundle + model_config.credentials = model_instance.credentials + model_config.model_schema = model_schema + planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: @@ -517,11 +523,12 @@ class DatasetRetrieval: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=segment.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), @@ -675,7 +682,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +759,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -986,6 +993,24 @@ class DatasetRetrieval: ) ) + @staticmethod + def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None: + """Map runtime user source values to dataset query audit roles. + + Workflow run context uses the hyphenated ``end-user`` value, while + ``DatasetQuery.created_by_role`` persists the underscore-based + ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an + unsupported value should be skipped instead of aborting retrieval. + """ + normalized_user_from = str(user_from).strip().lower().replace("-", "_") + if normalized_user_from == CreatorUserRole.ACCOUNT.value: + return CreatorUserRole.ACCOUNT + if normalized_user_from == CreatorUserRole.END_USER.value: + return CreatorUserRole.END_USER + + logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from) + return None + def _on_query( self, query: str | None, @@ -996,10 +1021,18 @@ class DatasetRetrieval: user_id: str, ): """ - Handle query. + Persist dataset query audit rows for retrieval requests. """ if not query and not attachment_ids: return + created_by = parse_uuid_str_or_none(user_id) + if created_by is None: + logger.debug( + "Skipping dataset query log: empty created_by user_id (user_from=%s, app_id=%s)", + user_from, + app_id, + ) + return dataset_queries = [] for dataset_id in dataset_ids: contents = [] @@ -1015,7 +1048,7 @@ class DatasetRetrieval: source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), - created_by=user_id, + created_by=created_by, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -1068,7 +1101,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -1411,7 +1444,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1463,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1565,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1575,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 23a2ac8386..e617a9660e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,8 +2,8 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a7..83e58fe0f9 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -3,13 +3,14 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -119,6 +120,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -150,19 +152,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 7a00e8a886..2c27ac3cf6 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -15,7 +15,7 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeee..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..cfa9962ea8 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7..d0164b76dc 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.workflow_execution import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550..52361cf6dc 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,11 +12,11 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf..dafdbf641a 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... + + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032..02625e242f 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,33 +2,23 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, + is_human_input_webapp_enabled, ) +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin @@ -36,6 +26,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +49,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str | None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... + + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, - ): + app_id: str | None = None, + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id + self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, @@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return is_human_input_webapp_enabled(form_config) + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") + workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, + app_id=app_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac..1ee5d4ae77 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,10 +9,10 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc..749ab44a14 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,12 +17,12 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of node execution instances """ - # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) + db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from) with ThreadPoolExecutor(max_workers=10) as executor: domain_models = executor.map(self._to_domain_model, db_models, timeout=30) diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py new file mode 100644 index 0000000000..ae4f53f3b7 --- /dev/null +++ b/api/core/telemetry/__init__.py @@ -0,0 +1,43 @@ +"""Telemetry facade. + +Thin public API for emitting telemetry events. All routing logic +lives in ``core.telemetry.gateway`` which is shared by both CE and EE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent +from core.telemetry.gateway import emit as gateway_emit +from core.telemetry.gateway import get_trace_task_to_case + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + + +def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None: + """Emit a telemetry event. + + Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a + ``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``. + """ + case = get_trace_task_to_case().get(event.name) + if case is None: + return + + context: dict[str, object] = { + "tenant_id": event.context.tenant_id, + "user_id": event.context.user_id, + "app_id": event.context.app_id, + } + gateway_emit(case, context, event.payload, trace_manager) + + +__all__ = [ + "TelemetryContext", + "TelemetryEvent", + "TraceTaskName", + "emit", +] diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py new file mode 100644 index 0000000000..35ace47510 --- /dev/null +++ b/api/core/telemetry/events.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.ops.entities.trace_entity import TraceTaskName + + +@dataclass(frozen=True) +class TelemetryContext: + tenant_id: str | None = None + user_id: str | None = None + app_id: str | None = None + + +@dataclass(frozen=True) +class TelemetryEvent: + name: TraceTaskName + context: TelemetryContext + payload: dict[str, Any] diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py new file mode 100644 index 0000000000..7b013d0563 --- /dev/null +++ b/api/core/telemetry/gateway.py @@ -0,0 +1,239 @@ +"""Telemetry gateway — single routing layer for all editions. + +Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either +the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only +metric/log Celery queue. + +This module lives in ``core/`` so both CE and EE share one routing table +and one ``emit()`` entry point. No separate enterprise gateway module is +needed — enterprise-specific dispatch (Celery task, payload offloading) +is handled here behind lazy imports that no-op in CE. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from core.ops.entities.trace_entity import TraceTaskName +from enterprise.telemetry.contracts import CaseRoute, SignalType +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + from enterprise.telemetry.contracts import TelemetryCase + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +# --------------------------------------------------------------------------- +# Routing table — authoritative mapping for all editions +# --------------------------------------------------------------------------- + +_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None +_case_routing: dict[TelemetryCase, CaseRoute] | None = None + + +def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]: + global _case_to_trace_task + if _case_to_trace_task is None: + from enterprise.telemetry.contracts import TelemetryCase + + _case_to_trace_task = { + TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE, + TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE, + TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE, + TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE, + TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE, + TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE, + TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE, + TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE, + TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE, + } + return _case_to_trace_task + + +def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]: + """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).""" + return {v: k for k, v in _get_case_to_trace_task().items()} + + +def _get_case_routing() -> dict[TelemetryCase, CaseRoute]: + global _case_routing + if _case_routing is None: + from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase + + _case_routing = { + # TRACE — CE-eligible (flow in both CE and EE) + TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + # TRACE — enterprise-only + TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + # METRIC_LOG — enterprise-only (signal-driven, not trace) + TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + } + return _case_routing + + +def __getattr__(name: str) -> dict: + """Lazy module-level access to routing tables.""" + if name == "CASE_ROUTING": + return _get_case_routing() + if name == "CASE_TO_TRACE_TASK": + return _get_case_to_trace_task() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def _handle_payload_sizing( + payload: dict[str, Any], + tenant_id: str, + event_id: str, +) -> tuple[dict[str, Any], str | None]: + """Inline or offload payload based on size. + + Returns ``(payload_for_envelope, storage_key | None)``. Payloads + exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object + storage and replaced with an empty dict in the envelope. + """ + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Route a telemetry event to the correct pipeline. + + TRACE events are enqueued into ``TraceQueueManager`` (works in both CE + and EE). Enterprise-only traces are silently dropped when EE is + disabled. + + METRIC_LOG events are dispatched to the enterprise Celery queue; + silently dropped when enterprise telemetry is unavailable. + """ + route = _get_case_routing().get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if not route.ce_eligible and not is_enterprise_telemetry_enabled(): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + if route.signal_type == SignalType.TRACE: + _emit_trace(case, context, payload, trace_manager) + else: + _emit_metric_log(case, context, payload) + + +def _emit_trace( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, +) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = _get_case_to_trace_task().get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + +def _emit_metric_log( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], +) -> None: + """Build envelope and dispatch to enterprise Celery queue. + + No-ops when the enterprise telemetry task is not importable (CE mode). + """ + try: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + except ImportError: + logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case) + return + + tenant_id = context.get("tenant_id") or "" + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id) + + from enterprise.telemetry.contracts import TelemetryEnvelope + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 961d13f90a..5154bc9805 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing. + + ``user_id`` is optional so read-only tooling flows can stay tenant-scoped, + while execution paths may bind caller identity for model runtime lookups. """ tenant_id: str + user_id: str | None = None tool_id: str | None = None invoke_from: InvokeFrom | None = None tool_invoke_from: ToolInvokeFrom | None = None diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index dacc49c746..40bf2e98c2 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -7,9 +7,9 @@ from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.file.enums import FileType -from dify_graph.file.file_manager import download -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.file.enums import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService @@ -22,6 +22,9 @@ class ASRTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: + if not self.runtime: + raise ValueError("Runtime is required") + runtime = self.runtime file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore yield self.create_text_message("not a valid audio file") @@ -29,20 +32,19 @@ class ASRTool(BuiltinTool): audio_binary = io.BytesIO(download(file)) # type: ignore audio_binary.name = "temp.mp3" provider, model = tool_parameters.get("model").split("#") # type: ignore - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=runtime.tenant_id, provider=provider, model_type=ModelType.SPEECH2TEXT, model=model, ) - text = model_instance.invoke_speech2text( - file=audio_binary, - user=user_id, - ) + text = model_instance.invoke_speech2text(file=audio_binary) yield self.create_text_message(text) def get_available_models(self) -> list[tuple[str, str]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type( tenant_id=self.runtime.tenant_id, model_type="speech2text" diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 7818bff0ab..ac3820f1ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -7,7 +7,7 @@ from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService @@ -20,13 +20,14 @@ class TTSTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +40,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 44f94c2723..e07ca0d919 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Any -from pytz import timezone as pytz_timezone +from pytz import timezone as pytz_timezone # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index d0a41b940f..dc49b64dd8 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 462e4be5ce..8045e4b980 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e23ae3b001..e2570811d6 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 00f5931088..d41503e1e6 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -4,8 +4,8 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but @@ -50,9 +50,10 @@ class BuiltinTool(Tool): return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id or "", - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, + caller_user_id=self.runtime.user_id, ) def tool_provider_type(self) -> ToolProviderType: @@ -69,6 +70,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", + user_id=self.runtime.user_id, ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -82,7 +84,9 @@ class BuiltinTool(Tool): raise ValueError("runtime is required") return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + tenant_id=self.runtime.tenant_id or "", + prompt_messages=prompt_messages, + user_id=self.runtime.user_id, ) def summary(self, user_id: str, content: str) -> str: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index c6a84e27c6..168e5f4493 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -13,7 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from dify_graph.file.file_manager import download +from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2545290b57..08640befb4 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -9,7 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 9025ff6ef1..00fc8a8282 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -21,7 +21,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 22e099deba..1807226924 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -3,6 +3,7 @@ import hashlib import hmac import os import time +import urllib.parse from configs import dify_config @@ -58,3 +59,43 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: + """Build the signed upload URL used by the plugin-facing file upload endpoint.""" + + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + upload_url = f"{base_url}/files/upload/for-plugin" + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": encoded_sign, + "user_id": user_id, + "tenant_id": tenant_id, + } + ) + return f"{upload_url}?{query}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str +) -> bool: + """Verify the signature used by the plugin-facing file upload endpoint.""" + + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 64212a2636..1fd259f3bb 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -31,9 +31,9 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FileType -from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db +from graphon.file import FileType +from graphon.file.models import FileTransferMethod from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 210f488afc..2ec292602c 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -14,8 +14,9 @@ import httpx from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,21 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +225,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,11 +247,11 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser -from dify_graph.file.tool_file_parser import set_tool_file_manager_factory +from graphon.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 90d5a647e9..250dd91bfd 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -38,7 +38,7 @@ class ToolLabelManager: db.session.add( ToolLabelBinding( tool_id=provider_id, - tool_type=controller.provider_type.value, + tool_type=controller.provider_type, label_name=label, ) ) @@ -58,7 +58,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") stmt = select(ToolLabelBinding.label_name).where( ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, + ToolLabelBinding.tool_type == controller.provider_type, ) labels = db.session.scalars(stmt).all() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e3..4870adb7b5 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -24,14 +24,14 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db +from graphon.runtime.variable_pool import VariablePool from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -57,12 +57,12 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,23 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + @property + def provider_type(self) -> ToolProviderType: ... + + @property + def provider_id(self) -> str: ... + + @property + def tool_name(self) -> str: ... + + @property + def tool_configurations(self) -> Mapping[str, Any]: ... + + @property + def credential_id(self) -> str | None: ... + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -167,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -178,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -196,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -304,6 +324,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, @@ -321,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -344,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -352,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -364,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -375,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -405,7 +442,8 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -418,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -450,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -460,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, @@ -1015,14 +1056,14 @@ class ToolManager: cls, parameters: list[ToolParameter], variable_pool: Optional["VariablePool"], - tool_configurations: dict[str, Any], + tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: """ Convert tool parameters type """ - from dify_graph.nodes.tool.entities import ToolNodeData - from dify_graph.nodes.tool.exc import ToolParameterError + from graphon.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..dad5133a7a 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { @@ -65,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e6622..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2d..5cf46b2564 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -10,12 +11,15 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8f958563bd..9e1d41cb39 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -9,19 +9,20 @@ from decimal import Decimal from typing import cast from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.errors.invoke import ( +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ToolModelInvoke @@ -33,11 +34,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -59,13 +61,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -78,7 +80,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -92,7 +99,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -136,7 +143,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 28f1376655..1e4f3ed2a7 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,9 +3,9 @@ from typing import Any from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.variables.input_entities import VariableEntity +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index aef8b3f779..716368c191 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -22,8 +22,8 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a741..495fcd48b3 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,6 +7,7 @@ from typing import Any, cast from sqlalchemy import select +from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,14 +18,17 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WorkflowTool(Tool): @@ -288,16 +292,25 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + build_file_from_stored_mapping( + file_mapping=cast(Mapping[str, Any], f), + tenant_id=str(self.runtime.tenant_id), + ) + for f in file + if isinstance(f, Mapping) + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() @@ -325,6 +338,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=item, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: @@ -332,6 +346,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=value, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) @@ -340,9 +355,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 2a133b2b94..24c1271488 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -26,8 +26,8 @@ from core.trigger.debug.events import ( ) from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 0000000000..c80acb3783 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = {"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 0000000000..75a0a0c202 --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,299 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`graphon` so graph-owned models only see graph-neutral field names. +""" + +from __future__ import annotations + +import enum +import uuid +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ClassVar, Literal + +import bleach +import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.runtime import VariablePool +from graphon.variables.consts import SELECTORS_LENGTH + + +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + BOUND = "member" + MEMBER = BOUND + EXTERNAL = "external" + + +class _InteractiveSurfaceDeliveryConfig(BaseModel): + pass + + +class BoundRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND + reference_id: str + + +class ExternalRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +MemberRecipient = BoundRecipient +EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + model_config = ConfigDict(extra="forbid") + + include_bound_group: bool = Field( + default=False, + validation_alias=AliasChoices("include_bound_group", "whole_workspace"), + ) + items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "br", + "code", + "em", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"} + + recipients: EmailRecipients + subject: str + body: str + debug_mode: bool = False + + def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig: + return self.model_copy(update={"recipients": recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + @classmethod + def render_markdown_body(cls, body: str) -> str: + stripped_body = bleach.clean(body, tags=[], attributes={}, strip=True) + rendered = markdown.markdown( + stripped_body, + extensions=[TableExtension(use_align_attribute=True)], + output_format="html", + ) + return bleach.clean( + rendered, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + ) + + @staticmethod + def sanitize_subject(subject: str) -> str: + sanitized = subject.replace("\r", " ").replace("\n", " ") + sanitized = bleach.clean(sanitized, tags=[], strip=True) + return " ".join(sanitized.split()) + + +class _DeliveryMethodBase(BaseModel): + enabled: bool = True + id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod +_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig + +DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + +_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig]) + + +def _copy_mapping(value: object) -> dict[str, Any] | None: + if isinstance(value, BaseModel): + return value.model_dump(mode="python") + if isinstance(value, Mapping): + return dict(value) + return None + + +def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") + + delivery_methods = normalized.get("delivery_methods") + if not isinstance(delivery_methods, list): + return normalized + + normalized_methods: list[Any] = [] + for method in delivery_methods: + method_mapping = _copy_mapping(method) + if method_mapping is None: + normalized_methods.append(method) + continue + + config_mapping = _copy_mapping(method_mapping.get("config")) + if config_mapping is not None: + recipients_mapping = _copy_mapping(config_mapping.get("recipients")) + if recipients_mapping is not None: + config_mapping["recipients"] = _normalize_email_recipients(recipients_mapping) + method_mapping["config"] = config_mapping + + normalized_methods.append(method_mapping) + + normalized["delivery_methods"] = normalized_methods + return normalized + + +def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: + normalized = normalize_human_input_node_data_for_graph(node_data) + raw_delivery_methods = normalized.get("delivery_methods") + if not isinstance(raw_delivery_methods, list): + return [] + return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_delivery_methods)) + + +def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool: + for method in parse_human_input_delivery_methods(node_data): + if method.enabled and method.type == DeliveryMethodType.WEBAPP: + return True + return False + + +def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") + + if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: + return normalized + return normalize_human_input_node_data_for_graph(normalized) + + +def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_config) + if normalized is None: + raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") + + data_mapping = _copy_mapping(normalized.get("data")) + if data_mapping is None: + return normalized + + normalized["data"] = normalize_node_data_for_graph(data_mapping) + return normalized + + +def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(recipients) + + legacy_include_bound_group = normalized.pop("whole_workspace", None) + if "include_bound_group" not in normalized and legacy_include_bound_group is not None: + normalized["include_bound_group"] = legacy_include_bound_group + + items = normalized.get("items") + if not isinstance(items, list): + return normalized + + normalized_items: list[Any] = [] + for item in items: + item_mapping = _copy_mapping(item) + if item_mapping is None: + normalized_items.append(item) + continue + + legacy_reference_id = item_mapping.pop("user_id", None) + if "reference_id" not in item_mapping and legacy_reference_id is not None: + item_mapping["reference_id"] = legacy_reference_id + normalized_items.append(item_mapping) + + normalized["items"] = normalized_items + return normalized + + +__all__ = [ + "BoundRecipient", + "DeliveryChannelConfig", + "DeliveryMethodType", + "EmailDeliveryConfig", + "EmailDeliveryMethod", + "EmailRecipientType", + "EmailRecipients", + "ExternalRecipient", + "MemberRecipient", + "WebAppDeliveryMethod", + "_WebAppDeliveryConfig", + "is_human_input_webapp_enabled", + "normalize_human_input_node_data_for_graph", + "normalize_node_config_for_graph", + "normalize_node_data_for_graph", + "parse_human_input_delivery_methods", +] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py new file mode 100644 index 0000000000..f124b321d4 --- /dev/null +++ b/api/core/workflow/human_input_forms.py @@ -0,0 +1,55 @@ +"""Shared helpers for workflow pause-time human input form lookups. + +Both controllers and streaming response converters need the same recipient +priority when exposing resume links for paused human input forms. Keep that +selection logic here so all API surfaces stay consistent. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.human_input import HumanInputFormRecipient, RecipientType + +_FORM_TOKEN_PRIORITY = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def load_form_tokens_by_form_id( + form_ids: Sequence[str], + *, + session: Session | None = None, +) -> dict[str, str]: + """Load the preferred access token for each human input form.""" + unique_form_ids = list(dict.fromkeys(form_ids)) + if not unique_form_ids: + return {} + + if session is not None: + return _load_form_tokens_by_form_id(session, unique_form_ids) + + with Session(bind=db.engine, expire_on_commit=False) as new_session: + return _load_form_tokens_by_form_id(new_session, unique_form_ids) + + +def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: + tokens_by_form_id: dict[str, tuple[int, str]] = {} + stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(stmt): + priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) + if priority is None or not recipient.access_token: + continue + + candidate = (priority, recipient.access_token) + current = tokens_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + tokens_by_form_id[recipient.form_id] = candidate + + return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index ab34263a79..028e38fbee 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -9,8 +9,8 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -19,45 +19,48 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey -from dify_graph.file.file_manager import file_manager -from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.nodes.document_extractor import UnstructuredApiConfig -from dify_graph.nodes.http_request import build_http_request_config -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import TemplateRenderer -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState LATEST_VERSION = "latest" _START_NODE_TYPES: frozenset[NodeType] = frozenset( @@ -76,7 +79,7 @@ def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] @lru_cache(maxsize=1) def register_nodes() -> None: """Import production node modules so they self-register with ``Node``.""" - _import_node_package("dify_graph.nodes") + _import_node_package("graphon.nodes") _import_node_package("core.workflow.nodes") @@ -84,7 +87,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] """Return a read-only snapshot of the current production node registry. The workflow layer owns node bootstrap because it must compose built-in - `dify_graph.nodes.*` implementations with workflow-local nodes under + `graphon.nodes.*` implementations with workflow-local nodes under `core.workflow.nodes.*`. Keeping this import side effect here avoids reintroducing registry bootstrapping into lower-level graph primitives. """ @@ -115,7 +118,7 @@ def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: This workflow-layer helper depends on start-node semantics defined by `is_start_node_type`, so it intentionally lives next to the node registry - instead of in the raw `dify_graph.entities.graph_config` schema module. + instead of in the raw `graphon.entities.graph_config` schema module. """ nodes = graph_config.get("nodes") if not isinstance(nodes, list): @@ -229,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -264,11 +257,31 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), + ) + self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +297,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -299,6 +312,9 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -310,7 +326,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -321,22 +337,29 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +368,26 @@ class DifyNodeFactory(NodeFactory): BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +421,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +436,35 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance @@ -452,12 +477,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 0000000000..2e632e56f0 --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,670 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + from graphon.file import File + from graphon.nodes.llm.file_saver import LLMFileSaver + from graphon.nodes.tool.entities import ToolNodeData + + +_file_access_controller = DatabaseFileAccessController() + + +def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext: + if isinstance(run_context, DifyRunContext): + return run_context + + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + +def apply_dify_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + actor_id: str | None, +) -> DeliveryChannelConfig: + """Apply the Dify debugger-specific email recipient override outside `graphon`.""" + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + + if actor_id is None: + debug_recipients = EmailRecipients(include_bound_group=False, items=[]) + else: + debug_recipients = EmailRecipients( + include_bound_group=False, + items=[BoundRecipient(reference_id=actor_id)], + ) + debug_config = method.config.with_recipients(debug_recipients) + return method.model_copy(update={"config": debug_config}) + + +class DifyFileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def build_from_mapping(self, *, mapping: Mapping[str, Any]): + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self._run_context.tenant_id, + access_controller=_file_access_controller, + ) + + +class DifyPreparedLLM(PreparedLLMProtocol): + """Workflow-layer adapter that hides the full `ModelInstance` API from `graphon` nodes.""" + + def __init__(self, model_instance: ModelInstance) -> None: + self._model_instance = model_instance + + @property + def provider(self) -> str: + return self._model_instance.provider + + @property + def model_name(self) -> str: + return self._model_instance.model_name + + @property + def parameters(self) -> Mapping[str, Any]: + return self._model_instance.parameters + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: + self._model_instance.parameters = value + + @property + def stop(self) -> Sequence[str] | None: + return self._model_instance.stop + + def get_model_schema(self) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( + self._model_instance.model_name, + self._model_instance.credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {self._model_instance.model_name}") + return model_schema + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: + return self._model_instance.get_llm_num_tokens(prompt_messages) + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + return self._model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=dict(model_parameters), + tools=list(tools or []), + stop=list(stop or []), + stream=stream, + ) + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + return invoke_llm_with_structured_output( + provider=self.provider, + model_schema=self.get_model_schema(), + model_instance=self._model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + stop=list(stop or []), + stream=stream, + ) + + def is_structured_output_parse_error(self, error: Exception) -> bool: + return isinstance(error, OutputParserError) + + +class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: + return PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_mode, + prompt_messages=prompt_messages, + ) + + +class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): + """Resolve retriever attachments through Dify persistence and return graph file references.""" + + def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + self._file_reference_factory = file_reference_factory + + def load(self, *, segment_id: str) -> Sequence[File]: + with Session(db.engine, expire_on_commit=False) as session: + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.segment_id == segment_id) + ).all() + + return [ + self._file_reference_factory.build_from_mapping( + mapping={ + "id": upload_file.id, + "filename": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "remote_url": upload_file.source_url, + "reference": build_file_reference(record_id=str(upload_file.id)), + "size": upload_file.size, + } + ) + for _, upload_file in attachments_with_bindings + ] + + +class DifyToolFileManager(ToolFileManagerProtocol): + """Workflow adapter that resolves conversation scope outside `graphon`.""" + + _conversation_id_getter: Callable[[], str | None] | None + + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + conversation_id_getter: Callable[[], str | None] | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._manager = ToolFileManager() + self._conversation_id_getter = conversation_id_getter + + def create_file_by_raw( + self, + *, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: + conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None + return self._manager.create_file_by_raw( + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=conversation_id, + file_binary=file_binary, + mimetype=mimetype, + filename=filename, + ) + + def get_file_generator_by_tool_file_id(self, tool_file_id: str): + return self._manager.get_file_generator_by_tool_file_id(tool_file_id) + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeSpec: + provider_type: CoreToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None = None + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeBinding: + """Workflow-private runtime state stored inside the opaque graph handle. + + The binding keeps conversation scope in `core.workflow` while `graphon` + continues to treat the handle as an opaque token. + """ + + tool: Tool + conversation_id: str | None = None + + +class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._file_reference_factory = DifyFileReferenceFactory(self._run_context) + + @property + def file_reference_factory(self) -> FileReferenceFactoryProtocol: + return self._file_reference_factory + + def build_file_reference(self, *, mapping: Mapping[str, Any]): + return self._file_reference_factory.build_from_mapping(mapping=mapping) + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool, + ) -> ToolRuntimeHandle: + try: + tool_runtime = ToolManager.get_workflow_tool_runtime( + self._run_context.tenant_id, + self._run_context.app_id, + node_id, + self._build_tool_runtime_spec(node_data), + self._run_context.user_id, + self._run_context.invoke_from, + variable_pool, + ) + except ToolNodeError: + raise + except Exception as exc: + raise ToolRuntimeResolutionError(str(exc)) from exc + + conversation_id = ( + None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + ) + return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: + tool = self._tool_from_handle(tool_runtime) + return [ + ToolRuntimeParameter(name=parameter.name, required=parameter.required) + for parameter in (tool.get_merged_runtime_parameters() or []) + ] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + runtime_binding = self._binding_from_handle(tool_runtime) + tool = runtime_binding.tool + callback = DifyWorkflowCallbackHandler() + + try: + messages = ToolEngine.generic_invoke( + tool=tool, + tool_parameters=dict(tool_parameters), + user_id=self._run_context.user_id, + workflow_tool_callback=callback, + workflow_call_depth=workflow_call_depth, + app_id=self._run_context.app_id, + conversation_id=runtime_binding.conversation_id, + ) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=runtime_binding.conversation_id, + ) + + return self._adapt_messages(transformed_messages, provider_name=provider_name) + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: + latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None) + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + return LLMUsage.model_validate(latest) + return LLMUsage.empty_usage() + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: + icon: str | Mapping[str, str] | None = default_icon + icon_dark: str | Mapping[str, str] | None = None + + manager = PluginInstaller() + plugins = manager.list_plugins(self._run_context.tenant_id) + try: + current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name) + icon = current_plugin.declaration.icon + except StopIteration: + pass + + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self._run_context.user_id, + self._run_context.tenant_id, + ) + if provider.name == provider_name + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + return icon, icon_dark + + @staticmethod + def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: + return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool + + @staticmethod + def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: + if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding): + return tool_runtime.raw + return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw)) + + @staticmethod + def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + return _WorkflowToolRuntimeSpec( + provider_type=CoreToolProviderType(node_data.provider_type.value), + provider_id=node_data.provider_id, + tool_name=node_data.tool_name, + tool_configurations=dict(node_data.tool_configurations), + credential_id=node_data.credential_id, + ) + + def _adapt_messages( + self, + messages: Generator[CoreToolInvokeMessage, None, None], + *, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + try: + for message in messages: + yield self._convert_message(message) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage: + graph_message_type = ToolRuntimeMessage.MessageType(message.type.value) + graph_message = self._convert_message_payload(message.message) + graph_meta = message.meta.copy() if message.meta is not None else None + return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta) + + def _convert_message_payload( + self, + message: CoreToolInvokeMessage.TextMessage + | CoreToolInvokeMessage.JsonMessage + | CoreToolInvokeMessage.BlobChunkMessage + | CoreToolInvokeMessage.BlobMessage + | CoreToolInvokeMessage.LogMessage + | CoreToolInvokeMessage.FileMessage + | CoreToolInvokeMessage.VariableMessage + | CoreToolInvokeMessage.RetrieverResourceMessage + | None, + ) -> ( + ToolRuntimeMessage.TextMessage + | ToolRuntimeMessage.JsonMessage + | ToolRuntimeMessage.BlobChunkMessage + | ToolRuntimeMessage.BlobMessage + | ToolRuntimeMessage.LogMessage + | ToolRuntimeMessage.FileMessage + | ToolRuntimeMessage.VariableMessage + | ToolRuntimeMessage.RetrieverResourceMessage + | None + ): + if message is None: + return None + + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + if isinstance(message, CoreToolInvokeMessage.TextMessage): + return ToolRuntimeMessage.TextMessage(text=message.text) + if isinstance(message, CoreToolInvokeMessage.JsonMessage): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + if isinstance(message, CoreToolInvokeMessage.BlobMessage): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + if isinstance(message, CoreToolInvokeMessage.FileMessage): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + if isinstance(message, CoreToolInvokeMessage.VariableMessage): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + if isinstance(message, CoreToolInvokeMessage.LogMessage): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + + @staticmethod + def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: + if isinstance(exc, ToolNodeError): + return exc + if isinstance(exc, PluginInvokeError): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + if isinstance(exc, PluginDaemonClientSideError): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + if isinstance(exc, ToolInvokeError): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + return ToolRuntimeInvocationError(str(exc)) + + +class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + workflow_execution_id_getter: Callable[[], str | None] | None = None, + form_repository: HumanInputFormRepository | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._workflow_execution_id_getter = workflow_execution_id_getter + self._form_repository = form_repository + + def _invoke_source(self) -> str: + invoke_from = self._run_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + + def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]: + invoke_source = self._invoke_source() + methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled] + if invoke_source in {"debugger", "explore"}: + methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP] + return [ + apply_dify_debug_email_recipient( + method, + enabled=invoke_source == "debugger", + actor_id=self._run_context.user_id, + ) + for method in methods + ] + + def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool: + if self._invoke_source() == "debugger": + return True + return is_human_input_webapp_enabled(node_data) + + def build_form_repository(self) -> HumanInputFormRepository: + if self._form_repository is not None: + return self._form_repository + + return self._build_form_repository() + + def _build_form_repository(self) -> HumanInputFormRepository: + invoke_source = self._invoke_source() + return HumanInputFormRepositoryImpl( + tenant_id=self._run_context.tenant_id, + app_id=self._run_context.app_id, + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + invoke_source=invoke_source, + submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, + ) + + def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime: + return DifyHumanInputNodeRuntime( + self._run_context, + workflow_execution_id_getter=self._workflow_execution_id_getter, + form_repository=form_repository, + ) + + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + repo = self.build_form_repository() + return repo.get_form(node_id) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: + repo = self.build_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=self._resolve_delivery_methods(node_data=node_data), + display_in_ui=self._display_in_ui(node_data=node_data), + resolved_default_values=resolved_default_values, + ) + return repo.create_form(params) + + +def build_dify_llm_file_saver( + *, + run_context: Mapping[str, Any] | DifyRunContext, + http_client: HttpClientProtocol, + conversation_id_getter: Callable[[], str | None] | None = None, +) -> LLMFileSaver: + from graphon.nodes.llm.file_saver import FileSaverImpl + + return FileSaverImpl( + tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter), + file_reference_factory=DifyFileReferenceFactory(run_context), + http_client=http_client, + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5699ccf404..7b000101b0 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,11 +3,13 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import AgentNodeData from .exceptions import ( @@ -19,8 +21,8 @@ from .runtime_support import AgentRuntimeSupport from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class AgentNode(Node[AgentNodeData]): @@ -59,7 +61,7 @@ class AgentNode(Node[AgentNodeData]): return "1" def populate_start_event(self, event) -> None: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { "name": self.node_data.agent_strategy_name, "icon": self._presentation_provider.get_icon( @@ -71,7 +73,7 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) try: strategy = self._strategy_resolver.resolve( @@ -97,6 +99,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, ) @@ -106,20 +109,21 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, for_log=True, ) credentials = self._runtime_support.build_credentials(parameters=parameters) - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) try: message_stream = strategy.invoke( params=parameters, user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + conversation_id=conversation_id, credentials=credentials, ) except Exception as e: @@ -146,6 +150,7 @@ class AgentNode(Node[AgentNodeData]): parameters_for_log=parameters_for_log, user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + conversation_id=conversation_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 91fed39795..51452c29a3 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f58a5665f4..f44681377d 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -6,27 +6,30 @@ from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session +from core.app.file_access import DatabaseFileAccessController from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( AgentLogEvent, NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.variables.segments import ArrayFileSegment -from extensions.ext_database import db -from factories import file_factory +from graphon.variables.segments import ArrayFileSegment from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError +_file_access_controller = DatabaseFileAccessController() + class AgentMessageTransformer: def transform( @@ -37,6 +40,7 @@ class AgentMessageTransformer: parameters_for_log: dict[str, Any], user_id: str, tenant_id: str, + conversation_id: str | None, node_type: NodeType, node_id: str, node_execution_id: str, @@ -47,7 +51,7 @@ class AgentMessageTransformer: messages=messages, user_id=user_id, tenant_id=tenant_id, - conversation_id=None, + conversation_id=conversation_id, ) text = "" @@ -70,10 +74,12 @@ class AgentMessageTransformer: url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -83,20 +89,23 @@ class AgentMessageTransformer: mapping = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } file = file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) @@ -111,6 +120,7 @@ class AgentMessageTransformer: file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9..a872774c98 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -12,16 +12,15 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.plugin.entities.request import InvokeCredentials -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from dify_graph.enums import SystemVariableKey -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated @@ -38,6 +37,7 @@ class AgentRuntimeSupport: node_data: AgentNodeData, strategy: ResolvedAgentStrategy, tenant_id: str, + user_id: str, app_id: str, invoke_from: Any, for_log: bool = False, @@ -141,6 +141,7 @@ class AgentRuntimeSupport: tenant_id, app_id, entity, + user_id, invoke_from, runtime_variable_pool, ) @@ -174,7 +175,11 @@ class AgentRuntimeSupport: value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) - model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + model_instance, model_schema = self.fetch_model( + tenant_id=tenant_id, + user_id=user_id, + value=value, + ) history_prompt_messages = [] if node_data.memory: memory = self.fetch_memory( @@ -219,10 +224,9 @@ class AgentRuntimeSupport: app_id: str, model_instance: ModelInstance, ) -> TokenBufferMemory | None: - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - if not isinstance(conversation_id_variable, StringSegment): + conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + if conversation_id is None: return None - conversation_id = conversation_id_variable.value with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) @@ -232,9 +236,15 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( + def fetch_model( + self, + *, + tenant_id: str, + user_id: str, + value: dict[str, Any], + ) -> tuple[ModelInstance, AIModelEntity | None]: + assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id) + provider_model_bundle = assembly.provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM, @@ -246,7 +256,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( + model_instance = assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 44f4a23a5a..38f39b3f94 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,22 +1,25 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -50,15 +53,14 @@ class DatasourceNode(Node[DatasourceNodeData]): """ Run the datasource node """ - - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE) if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None - datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO) if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segment.value @@ -131,12 +133,14 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) case DatasourceProviderType.LOCAL_FILE: - related_id = datasource_info.get("related_id") - if not related_id: + file_id = resolve_file_record_id( + datasource_info.get("reference") or datasource_info.get("related_id") + ) + if not file_id: raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=dify_ctx.tenant_id + file_id=file_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 65864474b0..28966f2392 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -3,8 +3,8 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/datasource/protocols.py b/api/core/workflow/nodes/datasource/protocols.py index c006e0885c..776e267317 100644 --- a/api/core/workflow/nodes/datasource/protocols.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import Any, Protocol -from dify_graph.file import File -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.file import File +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from .entities import DatasourceParameter, OnlineDriveDownloadFileParam diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8d2e9bf3cb..11339bb122 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4ea9091c5b..b465a2d8ff 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -6,12 +6,13 @@ from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, SystemVariableKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template +from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -19,8 +20,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) _INVOKE_FROM_DEBUGGER = "debugger" @@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # get dataset id as string - dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID) if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") dataset_id: str = dataset_id_segment.value # get document id as string (may be empty when not provided) - document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id: str = document_id_segment.value if document_id_segment else "" # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - invoke_from_value = str(invoke_from.value) if invoke_from else None + invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM) is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value @@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): outputs=outputs.model_dump(exclude_none=True), ) - original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID) + batch = get_system_segment(variable_pool, SystemVariableKey.BATCH) if not batch: raise KnowledgeIndexNodeError("Batch is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index bc5618685a..3f7cc364d3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -3,9 +3,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 80f59140be..117f426ade 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,26 +9,28 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from core.workflow.file_reference import parse_file_reference +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from dify_graph.variables.segments import ArrayObjectSegment +from graphon.variables.segments import ArrayObjectSegment from .entities import ( Condition, @@ -42,8 +44,8 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -160,7 +162,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -254,7 +256,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_model_config=node_data.metadata_model_config, metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, + attachment_ids=[ + parsed_reference.record_id + for attachment in attachments + if (parsed_reference := parse_file_reference(attachment.reference)) is not None + ] + if attachments + else None, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index e1311ab962..ea45dcf5c2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -3,8 +3,8 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.model_runtime.entities import LLMUsage -from dify_graph.nodes.llm.entities import ModelConfig +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition @@ -54,7 +54,7 @@ class KnowledgeRetrievalRequest(BaseModel): tenant_id: str = Field(description="Tenant unique identifier") user_id: str = Field(description="User unique identifier") app_id: str = Field(description="Application unique identifier") - user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')") + user_from: str = Field(description="User identity source for audit logging (e.g., 'account', 'end-user')") dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from") query: str | None = Field(default=None, description="Query text for knowledge retrieval") retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'") diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index ea7d20befe..23ed2cd408 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 118c2f2668..a2c952a899 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -2,11 +2,11 @@ from collections.abc import Mapping from typing import Any from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerEventNodeData @@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]): "plugin_unique_identifier": self.node_data.plugin_unique_identifier, }, } - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 95a2548678..207c1e7253 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -3,8 +3,8 @@ from typing import Literal, Union from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py index 336d64d58f..10962c3de4 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/core/workflow/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b9580e6ab1..dd80617dfc 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData @@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): } def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 242bf5ef6a..3125fe17e6 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -4,9 +4,9 @@ from enum import StrEnum from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py index 4d87f2a069..00b0b3baad 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/core/workflow/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 317844cbda..6858d6dc35 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -3,16 +3,17 @@ from collections.abc import Mapping from typing import Any from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.file import FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import FileVariable -from factories import file_factory +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment_with_type +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.file import FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.variables.types import SegmentType +from graphon.variables.variables import FileVariable from .entities import ContentType, WebhookData @@ -23,6 +24,13 @@ class TriggerWebhookNode(Node[WebhookData]): node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT + _file_reference_factory: FileReferenceFactoryProtocol + + def post_init(self) -> None: + from core.workflow.node_runtime import DifyFileReferenceFactory + + self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -53,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]): happens in the trigger controller. """ # Get webhook data from variable pool (injected by Celery task) - webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) # Extract webhook-specific outputs based on node configuration outputs = self._extract_configured_outputs(webhook_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=webhook_inputs, @@ -70,24 +76,20 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - dify_ctx = self.require_dify_context() - related_id = file.get("related_id") + file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: - file["upload_file_id"] = related_id + file["upload_file_id"] = file_id case FileTransferMethod.TOOL_FILE: - file["tool_file_id"] = related_id + file["tool_file_id"] = file_id case FileTransferMethod.DATASOURCE_FILE: - file["datasource_file_id"] = related_id + file["datasource_file_id"] = file_id try: - file_obj = file_factory.build_from_mapping( - mapping=file, - tenant_id=dify_ctx.tenant_id, - ) + file_obj = self._file_reference_factory.build_from_mapping(mapping=file) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) except ValueError: diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py new file mode 100644 index 0000000000..9d15a3fcea --- /dev/null +++ b/api/core/workflow/system_variables.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Protocol, cast +from uuid import uuid4 + +from graphon.enums import BuiltinNodeTypes +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.segments import Segment +from graphon.variables.variables import RAGPipelineVariableInput, Variable + +from .variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) + + +class SystemVariableKey(StrEnum): + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" + TIMESTAMP = "timestamp" + DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" + + +class _VariablePoolReader(Protocol): + def get(self, selector: Sequence[str], /) -> Segment | None: ... + + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ... + + +class _VariablePoolWriter(_VariablePoolReader, Protocol): + def add(self, selector: Sequence[str], value: object, /) -> None: ... + + +class _VariableLoader(Protocol): + def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ... + + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, "text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset( + ( + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + ) +) + + +def get_node_creation_preload_selectors( + *, + node_type: str, + node_data: object, +) -> tuple[tuple[str, str], ...]: + """Return selectors that must exist before node construction begins.""" + + if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None: + return () + + return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),) + + +def preload_node_creation_variables( + *, + variable_loader: _VariableLoader, + variable_pool: _VariablePoolWriter, + selectors: Sequence[Sequence[str]], +) -> None: + """Load constructor-time variables before node or graph creation.""" + + seen_selectors: set[tuple[str, ...]] = set() + selectors_to_load: list[list[str]] = [] + for selector in selectors: + normalized_selector = tuple(selector) + if len(normalized_selector) < 2: + raise ValueError(f"Invalid preload selector: {selector}") + if normalized_selector in seen_selectors: + continue + seen_selectors.add(normalized_selector) + if variable_pool.get(normalized_selector) is None: + selectors_to_load.append(list(normalized_selector)) + + loaded_variables = variable_loader.load_variables(selectors_to_load) + for variable in loaded_variables: + raw_selector = getattr(variable, "selector", ()) + loaded_selector = list(raw_selector) + if len(loaded_selector) < 2: + raise ValueError(f"Invalid loaded variable selector: {raw_selector}") + variable_pool.add(loaded_selector[:2], variable) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `graphon` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 0000000000..b4ffb37549 --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 0000000000..43523e01b2 --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 100% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..7429c95c7c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,36 +1,44 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file.models import File -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file.models import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: @@ -59,16 +67,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and only child-safe layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -79,17 +93,17 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), + graph_runtime_state=child_graph_runtime_state, + command_channel=command_channel, + config=config, child_engine_builder=self, ) child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -136,6 +150,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + execution_context = capture_current_context() + graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, @@ -212,6 +228,8 @@ class WorkflowEntry: # Get node type node_type = node_config_data.type + node_version = str(node_config_data.version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -226,15 +244,23 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + + preload_node_creation_variables( + variable_loader=variable_loader, + variable_pool=variable_pool, + selectors=get_node_creation_preload_selectors( + node_type=node_type, + node_data=node_config_data, + ), ) - node = node_factory.create_node(node_config) - node_cls = type(node) try: # variable selector to variable mapping @@ -243,6 +269,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. @@ -260,6 +292,13 @@ class WorkflowEntry: tenant_id=workflow.tenant_id, ) + # init workflow run state + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node = node_factory.create_node(node_config) + try: generator = cls._traced_node_run(node) except Exception as e: @@ -347,11 +386,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -366,7 +402,11 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +424,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -477,13 +523,21 @@ class WorkflowEntry: continue if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: - input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mapping( + mapping=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) if ( isinstance(input_value, list) and all(isinstance(item, dict) for item in input_value) and all("type" in item and "transfer_method" in item for item in input_value) ): - input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mappings( + mappings=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: diff --git a/api/core/workflow/workflow_run_outputs.py b/api/core/workflow/workflow_run_outputs.py new file mode 100644 index 0000000000..bd89f7c441 --- /dev/null +++ b/api/core/workflow/workflow_run_outputs.py @@ -0,0 +1,18 @@ +from collections.abc import Mapping +from typing import Any + +from graphon.enums import BuiltinNodeTypes, NodeType + + +def project_node_outputs_for_workflow_run( + *, + node_type: NodeType, + inputs: Mapping[str, Any], + outputs: Mapping[str, Any], +) -> dict[str, Any]: + """Project internal node outputs onto the workflow-run public contract.""" + + if node_type == BuiltinNodeTypes.START: + return dict(inputs) + + return dict(outputs) diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 100644 index 103f526bec..0000000000 --- a/api/dify_graph/context/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Execution Context - Context management for workflow execution. - -This package provides Flask-independent context management for workflow -execution in multi-threaded environments. -""" - -from dify_graph.context.execution_context import ( - AppContext, - ContextProviderNotFoundError, - ExecutionContext, - IExecutionContext, - NullAppContext, - capture_current_context, - read_context, - register_context, - register_context_capturer, - reset_context_provider, -) -from dify_graph.context.models import SandboxContext - -__all__ = [ - "AppContext", - "ContextProviderNotFoundError", - "ExecutionContext", - "IExecutionContext", - "NullAppContext", - "SandboxContext", - "capture_current_context", - "read_context", - "register_context", - "register_context_capturer", - "reset_context_provider", -] diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py deleted file mode 100644 index 17b19f2502..0000000000 --- a/api/dify_graph/conversation_variable_updater.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Protocol - -from dify_graph.variables import VariableBase - - -class ConversationVariableUpdater(Protocol): - """ - ConversationVariableUpdater defines an abstraction for updating conversation variable values. - - It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating - conversation variables. - - Implementations may choose to batch updates. If batching is used, the `flush` method - should be implemented to persist buffered changes, and `update` - should handle buffering accordingly. - - Note: Since implementations may buffer updates, instances of ConversationVariableUpdater - are not thread-safe. Each VariableAssignerNode should create its own instance during execution. - """ - - @abc.abstractmethod - def update(self, conversation_id: str, variable: "VariableBase"): - """ - Updates the value of the specified conversation variable in the underlying storage. - - :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `VariableBase` instance containing the updated value. - """ - pass - - @abc.abstractmethod - def flush(self): - """ - Flushes all pending updates to the underlying storage system. - - If the implementation does not buffer updates, this method can be a no-op. - """ - pass diff --git a/api/dify_graph/file/constants.py b/api/dify_graph/file/constants.py deleted file mode 100644 index 0665ed7e0d..0000000000 --- a/api/dify_graph/file/constants.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. -FILE_MODEL_IDENTITY = "__dify__file__" - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py deleted file mode 100644 index 310cb1310b..0000000000 --- a/api/dify_graph/file/helpers.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import os -import time -import urllib.parse - -from .runtime import get_workflow_file_runtime - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url) - url = f"{base_url}/files/{upload_file_id}/file-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign} - if as_attachment: - query["as_attachment"] = "true" - query_string = urllib.parse.urlencode(query) - - return f"{url}?{query_string}" - - -def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - runtime = get_workflow_file_runtime() - # Plugin access should use internal URL for Docker network communication. - base_url = runtime.internal_files_url or runtime.files_url - url = f"{base_url}/files/upload/for-plugin" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) - - -def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str -) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py deleted file mode 100644 index 24cbb42735..0000000000 --- a/api/dify_graph/file/protocols.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import Protocol - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``dify_graph.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def files_url(self) -> str: ... - - @property - def internal_files_url(self) -> str | None: ... - - @property - def secret_key(self) -> str: ... - - @property - def files_access_timeout(self) -> int: ... - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 5fa3d1634b..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,45 +0,0 @@ -import time - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. - """ - - model_type: ModelType = ModelType.MODERATION - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :param user: unique user id - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_moderation( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index e69069a85d..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import IO - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. - """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :param user: unique user id - :return: text for given audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index de0677a348..0000000000 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import annotations - -import hashlib -import logging -from collections.abc import Sequence -from threading import Lock - -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel -from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) - - -class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient - - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers - :return: list of providers - """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] - - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: - """ - Get all plugin model providers - :return: list of plugin model providers - """ - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema - :param provider: provider name - :return: provider schema - """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration - - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: - """ - Get plugin model provider - :param provider: provider name - :return: provider schema - """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: - raise ValueError(f"Invalid provider: {provider}") - - return plugin_model_provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - # validate provider credential schema - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - # validate model credential schema - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models - """ - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions - providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: - continue - - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models - - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } - - if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon - """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) - - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py deleted file mode 100644 index 0223149bb8..0000000000 --- a/api/dify_graph/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py deleted file mode 100644 index 2a33b4a0a8..0000000000 --- a/api/dify_graph/nodes/human_input/entities.py +++ /dev/null @@ -1,424 +0,0 @@ -""" -Human Input node entities. -""" - -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -import bleach -import markdown -from pydantic import BaseModel, Field, field_validator, model_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. - body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " ".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py deleted file mode 100644 index 9e95d341c9..0000000000 --- a/api/dify_graph/nodes/llm/protocols.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from core.model_manager import ModelInstance - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating initialized LLM model instances for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: - """Create a model instance that is ready for schema lookup and invocation.""" - ... - - -class TemplateRenderer(Protocol): - """Port for rendering prompt templates used by LLM-compatible nodes.""" - - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - """Render the given Jinja2 template into plain text.""" - ... diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py deleted file mode 100644 index 9b679d4497..0000000000 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage - - -class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" - - -class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc..0000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. - -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. -""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index 88966831cb..0000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f07649..0000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497..0000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. - - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892..0000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. - query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. - # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. - - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. - """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. - - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. - - Returns: - A view of the datasource info mapping to prevent modification - """ - if self._system_variable.datasource_info is None: - return None - return MappingProxyType(self._system_variable.datasource_info) - - def __repr__(self) -> str: - """Return a string representation of the read-only view.""" - return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/dify_graph/__init__.py b/api/enterprise/__init__.py similarity index 100% rename from api/dify_graph/__init__.py rename to api/enterprise/__init__.py diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 0000000000..60d482cd1c --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,525 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). + +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string | Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. + +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** + +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. + +#### Token Hierarchy & Query Patterns + +Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting: + +``` +App-level total +├── workflow ← sum of all node_execution tokens (DO NOT add both) +│ └── node_execution ← per-node breakdown +├── message ← independent (non-workflow chat apps only) +├── rule_generate ← independent helper LLM call +├── code_generate ← independent helper LLM call +├── structured_output ← independent helper LLM call +└── instruction_modify← independent helper LLM call +``` + +**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both. + +**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`. +App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries. + +**Common queries** (PromQL): + +```promql +# ── Totals ────────────────────────────────────────────────── +# App-level total (exclude node_execution to avoid double-counting) +sum by (app_id) (dify_tokens_total{operation_type!="node_execution"}) + +# Single app total +sum (dify_tokens_total{app_id="", operation_type!="node_execution"}) + +# Per-tenant totals +sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"}) + +# ── Drill-down ────────────────────────────────────────────── +# Workflow-level tokens for an app +sum (dify_tokens_total{app_id="", operation_type="workflow"}) + +# Node-level breakdown within an app +sum by (node_type) (dify_tokens_total{app_id="", operation_type="node_execution"}) + +# Model breakdown for an app +sum by (model_provider, model_name) (dify_tokens_total{app_id=""}) + +# Input vs output per model +sum by (model_name) (dify_tokens_input_total{app_id=""}) +sum by (model_name) (dify_tokens_output_total{app_id=""}) + +# ── Rates ─────────────────────────────────────────────────── +# Token consumption rate (per hour) +sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| `dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). + +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types + +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 0000000000..e43c0b1ea2 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,121 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. + +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules + +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` + +### Scenario A: Simple Workflow + +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). + +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow + +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** + +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution + +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. + +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. + +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/dify_graph/graph_engine/entities/__init__.py b/api/enterprise/telemetry/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/entities/__init__.py rename to api/enterprise/telemetry/__init__.py diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..91398cb8cb --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": ""}``. + """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..dff558988c --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from graphon.enums import WorkflowNodeExecutionMetadataKey +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + process_data = execution.process_data_dict or {} + + # Extract token breakdown from outputs.usage (set by LLM node) + usage: Mapping[str, Any] = {} + if isinstance(node_outputs, Mapping): + raw_usage = node_outputs.get("usage") + if isinstance(raw_usage, Mapping): + usage = raw_usage + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "model_provider": process_data.get("model_provider"), + "model_name": process_data.get("model_name"), + "prompt_tokens": usage.get("prompt_tokens"), + "completion_tokens": usage.get("completion_tokens"), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..fc17d9d93e --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,966 @@ +"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. + """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + else: + raise AssertionError("this statment should be unreachable") + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + "dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. + if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + + if info.start_time and info.end_time: + attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds() + + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.tool.name": info.tool_name, + "dify.tool.duration": float(info.time_cost), + "dify.tool.status": "failed" if info.error else "succeeded", + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.moderation.type": metadata.get("moderation_type", "input"), + "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error = info.error or (info.metadata.get("error") if info.metadata else None) + status = "failed" if error else (info.status or "succeeded") + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": status, + "dify.suggested_question.error": error, + "dify.suggested_question.duration": duration, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + "dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.retrieval.error"] = info.error + attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded" + if info.start_time and info.end_time: + attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds() + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.id"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.name"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error: str | None = metadata.get("error") if metadata else None + status = "failed" if error else "succeeded" + attrs["dify.generate_name.duration"] = duration + attrs["dify.generate_name.status"] = status + attrs["dify.generate_name.error"] = error + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "gen_ai.user.id": user_id, + "dify.app_id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.prompt_generation.operation_type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.duration": info.latency, + "dify.prompt_generation.status": "failed" if info.error else "succeeded", + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + prompt_status = "failed" if info.error else "succeeded" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=prompt_status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..4a9bd3dbf8 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + """Unified label structure for all dify.token.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries. + """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..d8b4208c69 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,72 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. +""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..b2f860764f --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ -0,0 +1,283 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import UTC, datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.baggage import get_all +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + ctx = W3CBaggagePropagator().extract({"baggage": raw}) + return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)} + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + # Ensure we always interpret naive datetimes as UTC instead of local time. + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + else: + dt = dt.astimezone(UTC) + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. + + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence." + ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..f3e5d6d0d6 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,75 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = uuid.UUID(source_id).int & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return uuid.UUID(correlation_id).int + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..ffd9a7e2b5 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,421 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. +""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. + + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. + """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. + if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "tenant_id": envelope.tenant_id, + "dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}", + "dify.telemetry.payload_type": envelope.case, + "dify.telemetry.correlation_id": envelope.event_id, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + "dify.app.created_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.updated_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.deleted_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + "dify.feedback.created_at": datetime.now(UTC).isoformat(), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event. + + Intentionally a no-op: metrics and structured logs for message runs are + emitted directly by EnterpriseOtelTrace._message_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event. + + Intentionally a no-op: metrics and structured logs for tool executions + are emitted directly by EnterpriseOtelTrace._tool_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event. + + Intentionally a no-op: metrics and structured logs for moderation checks + are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event. + + Intentionally a no-op: metrics and structured logs for suggested questions + are emitted directly by EnterpriseOtelTrace._suggested_question_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event. + + Intentionally a no-op: metrics and structured logs for dataset retrievals + are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event. + + Intentionally a no-op: metrics and structured logs for generate name + operations are emitted directly by EnterpriseOtelTrace._generate_name_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event. + + Intentionally a no-op: metrics and structured logs for prompt generation + operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..8cce4a9fcd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. +""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..2fba0028f9 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated" # sender: app, kwargs: synced_draft_workflow app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") + +# sender: app +app_was_updated = signal("app-was-updated") + +# sender: app +app_was_deleted = signal("app-was-deleted") diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4..ba9758175f 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,10 +1,11 @@ import logging +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) @@ -19,8 +20,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, @@ -30,7 +32,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 20852b818e..6769b94cde 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -3,9 +3,9 @@ from typing import cast from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 7b6a73af52..4eed34436a 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -2,7 +2,7 @@ import ssl from datetime import timedelta from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab @@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery: "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), } + if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..b3cfa01aee --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED`` +is false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a5baa21018..63edbe93e7 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -78,16 +78,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + + # Header field names must consist of lowercase letters, check RFC7540 + grpc_headers = ( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else () + ) + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 9a34acb0c1..120febecfb 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -10,7 +10,7 @@ def init_app(app: DifyApp): from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError + from graphon.model_runtime.errors.invoke import InvokeRateLimitError def before_send(event, hint): if "exc_info" in hint: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index a94d75ec76..64ff0f0674 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,10 +13,10 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -60,7 +60,7 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" - model.status = data.get("status") or "running" # Default status if missing + model.status = WorkflowNodeExecutionStatus(data.get("status") or "running") model.title = data.get("title") or "" created_by_role_val = data.get("created_by_role") try: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index bdfc81bd1c..5208f8f37e 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,10 +22,10 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb..ea4a2b3dd1 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -7,11 +7,11 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432..976b5db8e3 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -17,14 +17,14 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier +from graphon.entities import WorkflowNodeExecution +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. """ - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c275..c671e8b409 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. """ -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 544ef3fe18..eefcaa126e 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,5 +1,10 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. """ import json @@ -9,12 +14,23 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.file.models import File -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment +from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file.models import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment + + +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: @@ -101,10 +117,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 3da9a9e97d..ec3c78a12d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -8,10 +8,10 @@ from typing import Any from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index dd658b250b..56672d1fd4 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -8,11 +8,11 @@ from typing import Any from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index f4e6a18b4d..75ddbba448 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -4,12 +4,12 @@ Parser for tool nodes that captures tool-specific metadata. from opentelemetry.trace import Span -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py deleted file mode 100644 index cb07ba58ae..0000000000 --- a/api/factories/file_factory.py +++ /dev/null @@ -1,618 +0,0 @@ -import logging -import mimetypes -import os -import re -import urllib.parse -import uuid -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import httpx -from sqlalchemy import select -from sqlalchemy.orm import Session -from werkzeug.http import parse_options_header - -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.helper import ssrf_proxy -from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers -from extensions.ext_database import db -from models import MessageFile, ToolFile, UploadFile - -logger = logging.getLogger(__name__) - - -def build_from_message_files( - *, - message_files: Sequence["MessageFile"], - tenant_id: str, - config: FileUploadConfig | None = None, -) -> Sequence[File]: - results = [ - build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) - for file in message_files - if file.belongs_to != FileBelongsTo.ASSISTANT - ] - return results - - -def build_from_message_file( - *, - message_file: "MessageFile", - tenant_id: str, - config: FileUploadConfig | None, -): - mapping = { - "transfer_method": message_file.transfer_method, - "url": message_file.url, - "type": message_file.type, - } - - # Only include id if it exists (message_file has been committed to DB) - if message_file.id: - mapping["id"] = message_file.id - - # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE: - mapping["tool_file_id"] = message_file.upload_file_id - else: - mapping["upload_file_id"] = message_file.upload_file_id - - return build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - - -def build_from_mapping( - *, - mapping: Mapping[str, Any], - tenant_id: str, - config: FileUploadConfig | None = None, - strict_type_validation: bool = False, -) -> File: - transfer_method_value = mapping.get("transfer_method") - if not transfer_method_value: - raise ValueError("transfer_method is required in file mapping") - transfer_method = FileTransferMethod.value_of(transfer_method_value) - - build_functions: dict[FileTransferMethod, Callable] = { - FileTransferMethod.LOCAL_FILE: _build_from_local_file, - FileTransferMethod.REMOTE_URL: _build_from_remote_url, - FileTransferMethod.TOOL_FILE: _build_from_tool_file, - FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, - } - - build_func = build_functions.get(transfer_method) - if not build_func: - raise ValueError(f"Invalid file transfer method: {transfer_method}") - - file: File = build_func( - mapping=mapping, - tenant_id=tenant_id, - transfer_method=transfer_method, - strict_type_validation=strict_type_validation, - ) - - if config and not _is_file_valid_with_config( - input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension or "", - file_transfer_method=file.transfer_method, - config=config, - ): - raise ValueError(f"File validation failed for file: {file.filename}") - - return file - - -def build_from_mappings( - *, - mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None = None, - tenant_id: str, - strict_type_validation: bool = False, -) -> Sequence[File]: - # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. - # Implement batch processing to reduce database load when handling multiple files. - # Filter out None/empty mappings to avoid errors - def is_valid_mapping(m: Mapping[str, Any]) -> bool: - if not m or not m.get("transfer_method"): - return False - # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None - transfer_method = m.get("transfer_method") - if transfer_method == FileTransferMethod.REMOTE_URL: - url = m.get("url") or m.get("remote_url") - if not url: - return False - return True - - valid_mappings = [m for m in mappings if is_valid_mapping(m)] - files = [ - build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - strict_type_validation=strict_type_validation, - ) - for mapping in valid_mappings - ] - - if ( - config - # If image config is set. - and config.image_config - # And the number of image files exceeds the maximum limit - and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits - ): - raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config and config.number_limits and len(files) > config.number_limits: - raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - - return files - - -def _build_from_local_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if not upload_file_id: - raise ValueError("Invalid upload file id") - # check if upload_file_id is a valid uuid - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - row = db.session.scalar(stmt) - if row is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - specified_type = mapping.get("type", "custom") - - if strict_type_validation and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), - size=row.size, - storage_key=row.key, - ) - - -def _build_from_remote_url( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if upload_file_id: - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - upload_file = db.session.scalar(stmt) - if upload_file is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type( - extension="." + upload_file.extension, mime_type=upload_file.mime_type - ) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), - size=upload_file.size, - storage_key=upload_file.key, - ) - url = mapping.get("url") or mapping.get("remote_url") - if not url: - raise ValueError("Invalid file url") - - mime_type, filename, file_size = _get_remote_file_info(url) - extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=url, - mime_type=mime_type, - extension=extension, - size=file_size, - storage_key="", - ) - - -def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: - filename: str | None = None - # Try to extract from Content-Disposition header first - if content_disposition: - # Manually extract filename* parameter since parse_options_header doesn't support it - filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) - if filename_star_match: - raw_star = filename_star_match.group(1).strip() - # Remove trailing quotes if present - raw_star = raw_star.removesuffix('"') - # format: charset'lang'value - try: - parts = raw_star.split("'", 2) - charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" - value = parts[2] if len(parts) == 3 else parts[-1] - filename = urllib.parse.unquote(value, encoding=charset, errors="replace") - except Exception: - # Fallback: try to extract value after the last single quote - if "''" in raw_star: - filename = urllib.parse.unquote(raw_star.split("''")[-1]) - else: - filename = urllib.parse.unquote(raw_star) - - if not filename: - # Fallback to regular filename parameter - _, params = parse_options_header(content_disposition) - raw = params.get("filename") - if raw: - # Strip surrounding quotes and percent-decode if present - if len(raw) >= 2 and raw[0] == raw[-1] == '"': - raw = raw[1:-1] - filename = urllib.parse.unquote(raw) - # Fallback to URL path if no filename from header - if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None - # Defense-in-depth: ensure basename only - if filename: - filename = os.path.basename(filename) - # Return None if filename is empty or only whitespace - if not filename or not filename.strip(): - filename = None - return filename or None - - -def _guess_mime_type(filename: str) -> str: - """Guess MIME type from filename, returning empty string if None.""" - guessed_mime, _ = mimetypes.guess_type(filename) - return guessed_mime or "" - - -def _get_remote_file_info(url: str): - file_size = -1 - parsed_url = urllib.parse.urlparse(url) - url_path = parsed_url.path - filename = os.path.basename(url_path) - - # Initialize mime_type from filename as fallback - mime_type = _guess_mime_type(filename) - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - content_disposition = resp.headers.get("Content-Disposition") - extracted_filename = _extract_filename(url_path, content_disposition) - if extracted_filename: - filename = extracted_filename - mime_type = _guess_mime_type(filename) - file_size = int(resp.headers.get("Content-Length", file_size)) - # Fallback to Content-Type header if mime_type is still empty - if not mime_type: - mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() - - if not filename: - extension = mimetypes.guess_extension(mime_type) or ".bin" - filename = f"{uuid.uuid4().hex}{extension}" - if not mime_type: - mime_type = _guess_mime_type(filename) - - return mime_type, filename, file_size - - -def _build_from_tool_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") - - if not tool_file_id: - raise ValueError(f"ToolFile {tool_file_id} not found") - tool_file = db.session.scalar( - select(ToolFile).where( - ToolFile.id == tool_file_id, - ToolFile.tenant_id == tenant_id, - ) - ) - - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") - - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - tenant_id=tenant_id, - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) - - -def _build_from_datasource_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - datasource_file_id = mapping.get("datasource_file_id") - if not datasource_file_id: - raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = db.session.scalar( - select(UploadFile).where( - UploadFile.id == datasource_file_id, - UploadFile.tenant_id == tenant_id, - ) - ) - - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - related_id=datasource_file.id, - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) - - -def _is_file_valid_with_config( - *, - input_file_type: str, - file_extension: str, - file_transfer_method: FileTransferMethod, - config: FileUploadConfig, -) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions - if file_transfer_method == FileTransferMethod.TOOL_FILE: - return True - - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): - return False - - if ( - input_file_type == FileType.CUSTOM - and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions - ): - return False - - if input_file_type == FileType.IMAGE: - if ( - config.image_config - and config.image_config.transfer_methods - and file_transfer_method not in config.image_config.transfer_methods - ): - return False - elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: - return False - - return True - - -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = _get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - -class StorageKeyLoader: - """FileKeyLoader load the storage key from database for a list of files. - This loader is batched, the database query count is constant regardless of the input size. - """ - - def __init__(self, session: Session, tenant_id: str): - self._session = session - self._tenant_id = tenant_id - - def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: - stmt = select(UploadFile).where( - UploadFile.id.in_(upload_file_ids), - UploadFile.tenant_id == self._tenant_id, - ) - - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: - stmt = select(ToolFile).where( - ToolFile.id.in_(tool_file_ids), - ToolFile.tenant_id == self._tenant_id, - ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def load_storage_keys(self, files: Sequence[File]): - """Loads storage keys for a sequence of files by retrieving the corresponding - `UploadFile` or `ToolFile` records from the database based on their transfer method. - - This method doesn't modify the input sequence structure but updates the `_storage_key` - property of each file object by extracting the relevant key from its database record. - - Performance note: This is a batched operation where database query count remains constant - regardless of input size. However, for optimal performance, input sequences should contain - fewer than 1000 files. For larger collections, split into smaller batches and process each - batch separately. - """ - - upload_file_ids: list[uuid.UUID] = [] - tool_file_ids: list[uuid.UUID] = [] - for file in files: - related_model_id = file.related_id - if file.related_id is None: - raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) - - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(model_id) - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(model_id) - - tool_files = self._load_tool_files(tool_file_ids) - upload_files = self._load_upload_files(upload_file_ids) - for file in files: - model_id = uuid.UUID(file.related_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(model_id) - if upload_file_row is None: - raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(model_id) - if tool_file_row is None: - raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/__init__.py b/api/factories/file_factory/__init__.py new file mode 100644 index 0000000000..ae0cd972ec --- /dev/null +++ b/api/factories/file_factory/__init__.py @@ -0,0 +1,18 @@ +"""Workflow file factory package. + +This package normalizes workflow-layer file payloads into graph-layer ``File`` +values. It keeps tenancy and ownership checks in the application layer and +exports the workflow-facing file builders for callers. +""" + +from .builders import build_from_mapping, build_from_mappings +from .message_files import build_from_message_file, build_from_message_files +from .storage_keys import StorageKeyLoader + +__all__ = [ + "StorageKeyLoader", + "build_from_mapping", + "build_from_mappings", + "build_from_message_file", + "build_from_message_files", +] diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py new file mode 100644 index 0000000000..bc87510d43 --- /dev/null +++ b/api/factories/file_factory/builders.py @@ -0,0 +1,329 @@ +"""Core builders for workflow file mappings.""" + +from __future__ import annotations + +import mimetypes +import uuid +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import select + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers +from graphon.file.file_factory import standardize_file_type +from models import ToolFile, UploadFile + +from .common import resolve_mapping_file_id +from .remote import get_remote_file_info +from .validation import is_file_valid_with_config + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + transfer_method_value = mapping.get("transfer_method") + if not transfer_method_value: + raise ValueError("transfer_method is required in file mapping") + + transfer_method = FileTransferMethod.value_of(transfer_method_value) + build_func = _get_build_function(transfer_method) + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + + if config and not is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension or "", + file_transfer_method=file.transfer_method, + config=config, + ): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None = None, + tenant_id: str, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. + valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)] + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + for mapping in valid_mappings + ] + + if ( + config + and config.image_config + and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config and config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _get_build_function(transfer_method: FileTransferMethod): + build_functions = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, + } + build_func = build_functions.get(transfer_method) + if build_func is None: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + return build_func + + +def _resolve_file_type( + *, + detected_file_type: FileType, + specified_type: str | None, + strict_type_validation: bool, +) -> FileType: + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + if specified_type and specified_type != "custom": + return FileType(specified_type) + return detected_file_type + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") + + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) + + url = mapping.get("url") or mapping.get("remote_url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=filename, + type=file_type, + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id") + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") + + stmt = select(ToolFile).where( + ToolFile.id == tool_file_id, + ToolFile.tenant_id == tenant_id, + ) + tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + + +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") + + stmt = select(UploadFile).where( + UploadFile.id == datasource_file_id, + UploadFile.tenant_id == tenant_id, + ) + datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) + + +def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: + if not mapping or not mapping.get("transfer_method"): + return False + + if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL: + url = mapping.get("url") or mapping.get("remote_url") + if not url: + return False + + return True diff --git a/api/factories/file_factory/common.py b/api/factories/file_factory/common.py new file mode 100644 index 0000000000..2e1c95ab3f --- /dev/null +++ b/api/factories/file_factory/common.py @@ -0,0 +1,27 @@ +"""Shared helpers for workflow file factory modules.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.workflow.file_reference import resolve_file_record_id + + +def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + """Resolve historical file identifiers from persisted mapping payloads. + + Workflow and model payloads can outlive file schema changes. Older rows may + still carry concrete identifiers in legacy fields such as ``related_id``, + while newer payloads use opaque references. Keep this compatibility lookup in + the factory layer so historical data remains readable without reintroducing + storage details into graph-layer ``File`` values. + """ + + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py new file mode 100644 index 0000000000..4b3d514238 --- /dev/null +++ b/api/factories/file_factory/message_files.py @@ -0,0 +1,59 @@ +"""Adapters from persisted message files to graph-layer file values.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from core.app.file_access import FileAccessControllerProtocol +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig +from models import MessageFile + +from .builders import build_from_mapping + + +def build_from_message_files( + *, + message_files: Sequence[MessageFile], + tenant_id: str, + config: FileUploadConfig | None = None, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + return [ + build_from_message_file( + message_file=message_file, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) + for message_file in message_files + if message_file.belongs_to != FileBelongsTo.ASSISTANT + ] + + +def build_from_message_file( + *, + message_file: MessageFile, + tenant_id: str, + config: FileUploadConfig | None, + access_controller: FileAccessControllerProtocol, +) -> File: + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "type": message_file.type, + } + + if message_file.id: + mapping["id"] = message_file.id + + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = message_file.upload_file_id + else: + mapping["upload_file_id"] = message_file.upload_file_id + + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py new file mode 100644 index 0000000000..e5a7186007 --- /dev/null +++ b/api/factories/file_factory/remote.py @@ -0,0 +1,91 @@ +"""Remote file metadata helpers used by workflow file normalization. + +These helpers are part of the ``factories.file_factory`` package surface +because both workflow builders and tests rely on the same RFC5987 filename +parsing and HEAD-response normalization rules. +""" + +from __future__ import annotations + +import mimetypes +import os +import re +import urllib.parse +import uuid + +import httpx +from werkzeug.http import parse_options_header + +from core.helper import ssrf_proxy + + +def extract_filename(url_path: str, content_disposition: str | None) -> str | None: + """Extract a safe filename from Content-Disposition or the request URL path.""" + filename: str | None = None + if content_disposition: + filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) + if filename_star_match: + raw_star = filename_star_match.group(1).strip() + raw_star = raw_star.removesuffix('"') + try: + parts = raw_star.split("'", 2) + charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" + value = parts[2] if len(parts) == 3 else parts[-1] + filename = urllib.parse.unquote(value, encoding=charset, errors="replace") + except Exception: + if "''" in raw_star: + filename = urllib.parse.unquote(raw_star.split("''")[-1]) + else: + filename = urllib.parse.unquote(raw_star) + + if not filename: + _, params = parse_options_header(content_disposition) + raw = params.get("filename") + if raw: + if len(raw) >= 2 and raw[0] == raw[-1] == '"': + raw = raw[1:-1] + filename = urllib.parse.unquote(raw) + + if not filename: + candidate = os.path.basename(url_path) + filename = urllib.parse.unquote(candidate) if candidate else None + + if filename: + filename = os.path.basename(filename) + if not filename or not filename.strip(): + filename = None + + return filename or None + + +def _guess_mime_type(filename: str) -> str: + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + +def get_remote_file_info(url: str) -> tuple[str, str, int]: + """Resolve remote file metadata with SSRF-safe HEAD probing.""" + file_size = -1 + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + mime_type = _guess_mime_type(filename) + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) + file_size = int(resp.headers.get("Content-Length", file_size)) + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + + return mime_type, filename, file_size diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py new file mode 100644 index 0000000000..dba4c84407 --- /dev/null +++ b/api/factories/file_factory/storage_keys.py @@ -0,0 +1,106 @@ +"""Batched storage-key hydration for workflow files.""" + +from __future__ import annotations + +import uuid +from collections.abc import Mapping, Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference, parse_file_reference +from graphon.file import File, FileTransferMethod +from models import ToolFile, UploadFile + + +class StorageKeyLoader: + """Load storage keys for files with a constant number of database queries.""" + + _session: Session + _tenant_id: str + _access_controller: FileAccessControllerProtocol + + def __init__( + self, + session: Session, + tenant_id: str, + access_controller: FileAccessControllerProtocol, + ) -> None: + self._session = session + self._tenant_id = tenant_id + self._access_controller = access_controller + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_upload_file_filters(stmt) + return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_tool_file_filters(stmt) + return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)} + + def load_storage_keys(self, files: Sequence[File]) -> None: + """Hydrate storage keys by loading their backing file rows in batches. + + The sequence shape is preserved. Each file is updated in place with a + canonical record reference and storage key loaded from an authorized + database row. Tenant scoping is enforced by this loader's context + rather than by embedding tenant identity or storage paths inside + graph-layer ``File`` values. + + For best performance, prefer batches smaller than 1000 files. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + ) + file.storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + ) + file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py new file mode 100644 index 0000000000..4c4f6150e4 --- /dev/null +++ b/api/factories/file_factory/validation.py @@ -0,0 +1,44 @@ +"""Validation helpers for workflow file inputs.""" + +from __future__ import annotations + +from graphon.file import FileTransferMethod, FileType, FileUploadConfig + + +def is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) + # These are internally generated and should bypass user upload restrictions + if file_transfer_method == FileTransferMethod.TOOL_FILE: + return True + + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): + return False + + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): + return False + + if input_file_type == FileType.IMAGE: + if ( + config.image_config + and config.image_config.transfer_methods + and file_transfer_method not in config.image_config.transfer_methods + ): + return False + elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 14a56bf4a2..fd7acb14d3 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,75 +1,51 @@ +"""Compatibility factory for non-graph variable bootstrapping. + +Graph runtime segment/variable conversions live under `graphon.variables`. +This module keeps the application-layer mapping helpers and re-exports the +shared conversion functions for legacy callers and tests. +""" + from collections.abc import Mapping, Sequence from typing import Any, cast -from uuid import uuid4 from configs import dify_config -from dify_graph.constants import ( +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from dify_graph.file import File -from dify_graph.variables.exc import VariableError -from dify_graph.variables.segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, +from graphon.variables.exc import VariableError +from graphon.variables.factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import ( - ArrayAnyVariable, +from graphon.variables.types import SegmentType +from graphon.variables.variables import ( ArrayBooleanVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, BooleanVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, VariableBase, ) - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -# Define the constant -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} +__all__ = [ + "TypeMismatchError", + "UnsupportedSegmentTypeError", + "build_conversation_variable_from_mapping", + "build_environment_variable_from_mapping", + "build_pipeline_variable_from_mapping", + "build_segment", + "build_segment_with_type", + "segment_to_variable", +] def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: @@ -135,172 +111,3 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(VariableBase, result) - - -def build_segment(value: Any, /) -> Segment: - # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` - # below - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - elif len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_segment_factory: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - # Array types - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """ - Build a segment with explicit type checking. - - This function creates a segment from a value while enforcing type compatibility - with the specified segment_type. It provides stricter type validation compared - to the standard build_segment function. - - Args: - segment_type: The expected SegmentType for the resulting segment - value: The value to be converted into a segment - - Returns: - Segment: A segment instance of the appropriate type - - Raises: - TypeMismatchError: If the value type doesn't match the expected segment_type - - Special Cases: - - For empty list [] values, if segment_type is array[*], returns the corresponding array type - - Type validation is performed before segment creation - - Examples: - >>> build_segment_with_type(SegmentType.STRING, "hello") - StringSegment(value="hello") - - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) - ArrayStringSegment(value=[]) - - >>> build_segment_with_type(SegmentType.STRING, 123) - # Raises TypeMismatchError - """ - # Handle None values - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - # Handle empty list special case for array types - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - elif segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - elif segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - elif segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - elif segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - elif segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - # Type compatibility checking - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) - elif segment_type == SegmentType.NUMBER and inferred_type in ( - SegmentType.INTEGER, - SegmentType.FLOAT, - ): - segment_class = _segment_factory[inferred_type] - return segment_class(value_type=inferred_type, value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value_type=segment.value_type, - value=segment.value, - selector=list(selector), - ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index ac7c5376fb..b5acbbbcb4 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from dify_graph.variables.segments import Segment -from dify_graph.variables.types import SegmentType +from graphon.variables.segments import Segment +from graphon.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index a5c7ddbb11..801949747e 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from dify_graph.file import File +from graphon.file import File JSONValue: TypeAlias = Any @@ -311,7 +311,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValue) -> JSONValue: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 7ee628726b..4e201e66e6 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -5,7 +5,7 @@ from datetime import datetime from flask_restx import fields from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 428f92ed33..86c4f285cd 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -7,8 +7,8 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile +from graphon.file import File JSONValueType: TypeAlias = JSONValue @@ -133,7 +133,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/raws.py b/api/fields/raws.py index 318dedc25c..ee6f53b360 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,6 @@ from flask_restx import fields -from dify_graph.file import File +from graphon.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 7ce2139687..f9b5e98936 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields from core.helper import encrypter -from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/dify_graph/README.md b/api/graphon/README.md similarity index 98% rename from api/dify_graph/README.md rename to api/graphon/README.md index 2fc5b8b890..725f122cd8 100644 --- a/api/dify_graph/README.md +++ b/api/graphon/README.md @@ -114,7 +114,7 @@ The codebase enforces strict layering via import-linter: 1. Inherit from `BaseNode` or appropriate base class 1. Implement `_run()` method 1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/dify_graph/nodes/` +1. Add tests in `tests/unit_tests/graphon/nodes/` ### Implementing a Custom Layer diff --git a/api/dify_graph/model_runtime/__init__.py b/api/graphon/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/__init__.py rename to api/graphon/__init__.py diff --git a/api/dify_graph/entities/__init__.py b/api/graphon/entities/__init__.py similarity index 100% rename from api/dify_graph/entities/__init__.py rename to api/graphon/entities/__init__.py diff --git a/api/dify_graph/entities/base_node_data.py b/api/graphon/entities/base_node_data.py similarity index 98% rename from api/dify_graph/entities/base_node_data.py rename to api/graphon/entities/base_node_data.py index 47b37c9daf..e8267043a9 100644 --- a/api/dify_graph/entities/base_node_data.py +++ b/api/graphon/entities/base_node_data.py @@ -8,8 +8,8 @@ from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, model_validator -from dify_graph.entities.exc import DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType +from graphon.entities.exc import DefaultValueTypeError +from graphon.enums import ErrorStrategy, NodeType # Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. _NumberType = Union[int, float] diff --git a/api/dify_graph/entities/exc.py b/api/graphon/entities/exc.py similarity index 100% rename from api/dify_graph/entities/exc.py rename to api/graphon/entities/exc.py diff --git a/api/dify_graph/entities/graph_config.py b/api/graphon/entities/graph_config.py similarity index 89% rename from api/dify_graph/entities/graph_config.py rename to api/graphon/entities/graph_config.py index 36f7b94e82..392241c631 100644 --- a/api/dify_graph/entities/graph_config.py +++ b/api/graphon/entities/graph_config.py @@ -4,7 +4,7 @@ import sys from pydantic import TypeAdapter, with_config -from dify_graph.entities.base_node_data import BaseNodeData +from graphon.entities.base_node_data import BaseNodeData if sys.version_info >= (3, 12): from typing import TypedDict diff --git a/api/dify_graph/entities/graph_init_params.py b/api/graphon/entities/graph_init_params.py similarity index 100% rename from api/dify_graph/entities/graph_init_params.py rename to api/graphon/entities/graph_init_params.py diff --git a/api/dify_graph/entities/pause_reason.py b/api/graphon/entities/pause_reason.py similarity index 80% rename from api/dify_graph/entities/pause_reason.py rename to api/graphon/entities/pause_reason.py index 86d8c8ca16..ba2973fd45 100644 --- a/api/dify_graph/entities/pause_reason.py +++ b/api/graphon/entities/pause_reason.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.entities import FormInput, UserAction class PauseReasonType(StrEnum): @@ -18,7 +18,6 @@ class HumanInputRequired(BaseModel): form_content: str inputs: list[FormInput] = Field(default_factory=list) actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False node_id: str node_title: str @@ -33,13 +32,6 @@ class HumanInputRequired(BaseModel): # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None - class SchedulingPause(BaseModel): TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE diff --git a/api/dify_graph/entities/workflow_execution.py b/api/graphon/entities/workflow_execution.py similarity index 79% rename from api/dify_graph/entities/workflow_execution.py rename to api/graphon/entities/workflow_execution.py index 459ac46415..b8de7eed1a 100644 --- a/api/dify_graph/entities/workflow_execution.py +++ b/api/graphon/entities/workflow_execution.py @@ -1,26 +1,23 @@ """ Domain entities for workflow execution. -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. +Models describe graph runtime state and avoid infrastructure-specific details. """ from __future__ import annotations from collections.abc import Mapping -from datetime import datetime +from datetime import UTC, datetime from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from libs.datetime_utils import naive_utc_now +from graphon.enums import WorkflowExecutionStatus, WorkflowType class WorkflowExecution(BaseModel): """ - Domain model for workflow execution based on WorkflowRun but without - user, tenant, and app attributes. + Domain model for a workflow execution within the graph runtime. """ id_: str = Field(...) @@ -47,7 +44,7 @@ class WorkflowExecution(BaseModel): Calculate elapsed time in seconds. If workflow is not finished, use current time. """ - end_time = self.finished_at or naive_utc_now() + end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) return (end_time - self.started_at).total_seconds() @classmethod diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/graphon/entities/workflow_node_execution.py similarity index 85% rename from api/dify_graph/entities/workflow_node_execution.py rename to api/graphon/entities/workflow_node_execution.py index bc7e0d02e5..5458572e7e 100644 --- a/api/dify_graph/entities/workflow_node_execution.py +++ b/api/graphon/entities/workflow_node_execution.py @@ -1,9 +1,8 @@ """ Domain entities for workflow node execution. -This module contains the domain model for workflow node execution, which is used -by the core workflow module. These models are independent of the storage mechanism -and don't contain implementation details like tenant_id, app_id, etc. +These models capture node-level execution state for the graph runtime without +describing storage or application-layer concerns. """ from collections.abc import Mapping @@ -12,20 +11,15 @@ from typing import Any from pydantic import BaseModel, Field, PrivateAttr -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class WorkflowNodeExecution(BaseModel): """ Domain model for workflow node execution. - This model represents the core business entity of a node execution, - without implementation details like tenant_id, app_id, etc. - - Note: User/context-specific fields (triggered_from, created_by, created_by_role) - have been moved to the repository implementation to keep the domain model clean. - These fields are still accepted in the constructor for backward compatibility, - but they are not stored in the model. + This model represents the graph-level record of a node execution and + contains only execution state relevant to the runtime. """ # --------- Core identification fields --------- @@ -41,7 +35,7 @@ class WorkflowNodeExecution(BaseModel): # In most scenarios, `id` should be used as the primary identifier. node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow diff --git a/api/dify_graph/entities/workflow_start_reason.py b/api/graphon/entities/workflow_start_reason.py similarity index 100% rename from api/dify_graph/entities/workflow_start_reason.py rename to api/graphon/entities/workflow_start_reason.py diff --git a/api/dify_graph/enums.py b/api/graphon/enums.py similarity index 93% rename from api/dify_graph/enums.py rename to api/graphon/enums.py index cfb135cbb0..bbc973abe5 100644 --- a/api/dify_graph/enums.py +++ b/api/graphon/enums.py @@ -10,30 +10,6 @@ class NodeState(StrEnum): SKIPPED = "skipped" -class SystemVariableKey(StrEnum): - """ - System Variables. - """ - - QUERY = "query" - FILES = "files" - CONVERSATION_ID = "conversation_id" - USER_ID = "user_id" - DIALOGUE_COUNT = "dialogue_count" - APP_ID = "app_id" - WORKFLOW_ID = "workflow_id" - WORKFLOW_EXECUTION_ID = "workflow_run_id" - TIMESTAMP = "timestamp" - # RAG Pipeline - DOCUMENT_ID = "document_id" - ORIGINAL_DOCUMENT_ID = "original_document_id" - BATCH = "batch" - DATASET_ID = "dataset_id" - DATASOURCE_TYPE = "datasource_type" - DATASOURCE_INFO = "datasource_info" - INVOKE_FROM = "invoke_from" - - NodeType: TypeAlias = str @@ -41,7 +17,7 @@ class BuiltinNodeTypes: """Built-in node type string constants. `node_type` values are plain strings throughout the graph runtime. This namespace - only exposes the built-in values shipped by `dify_graph`; downstream packages can + only exposes the built-in values shipped by `graphon`; downstream packages can use additional strings without extending this class. """ diff --git a/api/dify_graph/errors.py b/api/graphon/errors.py similarity index 89% rename from api/dify_graph/errors.py rename to api/graphon/errors.py index 463d17713e..7eb007524d 100644 --- a/api/dify_graph/errors.py +++ b/api/graphon/errors.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.node import Node +from graphon.nodes.base.node import Node class WorkflowNodeRunFailedError(Exception): diff --git a/api/dify_graph/file/__init__.py b/api/graphon/file/__init__.py similarity index 75% rename from api/dify_graph/file/__init__.py rename to api/graphon/file/__init__.py index 44749ebec3..4908ae9795 100644 --- a/api/dify_graph/file/__init__.py +++ b/api/graphon/file/__init__.py @@ -1,5 +1,6 @@ from .constants import FILE_MODEL_IDENTITY from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .file_factory import get_file_type_by_mime_type, standardize_file_type from .models import ( File, FileUploadConfig, @@ -16,4 +17,6 @@ __all__ = [ "FileType", "FileUploadConfig", "ImageConfig", + "get_file_type_by_mime_type", + "standardize_file_type", ] diff --git a/api/graphon/file/constants.py b/api/graphon/file/constants.py new file mode 100644 index 0000000000..56b95b5f0d --- /dev/null +++ b/api/graphon/file/constants.py @@ -0,0 +1,48 @@ +from collections.abc import Iterable +from typing import Any + +# TODO(QuantumGhost): Refactor variable type identification. Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation +# details of how different variable types are identified. +FILE_MODEL_IDENTITY = "__dify__file__" +DEFAULT_MIME_TYPE = "application/octet-stream" +DEFAULT_EXTENSION = ".bin" + + +def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]: + normalized = {extension.lower() for extension in extensions} + return frozenset(normalized | {extension.upper() for extension in normalized}) + + +IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"}) +VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"}) +AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"}) +DOCUMENT_EXTENSIONS = _with_case_variants( + { + "txt", + "markdown", + "md", + "mdx", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "vtt", + "properties", + "doc", + "docx", + "csv", + "eml", + "msg", + "ppt", + "pptx", + "xml", + "epub", + } +) + + +def maybe_file_object(o: Any) -> bool: + return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/enums.py b/api/graphon/file/enums.py similarity index 100% rename from api/dify_graph/file/enums.py rename to api/graphon/file/enums.py diff --git a/api/graphon/file/file_factory.py b/api/graphon/file/file_factory.py new file mode 100644 index 0000000000..3d20b9377d --- /dev/null +++ b/api/graphon/file/file_factory.py @@ -0,0 +1,39 @@ +from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS +from .enums import FileType + + +def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: + """ + Infer the actual file type from extension and mime type. + """ + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = get_file_type_by_mime_type(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + normalized_extension = extension.lstrip(".") + if normalized_extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + if normalized_extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + if normalized_extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + if normalized_extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + return None + + +def get_file_type_by_mime_type(mime_type: str) -> FileType: + if "image" in mime_type: + return FileType.IMAGE + if "video" in mime_type: + return FileType.VIDEO + if "audio" in mime_type: + return FileType.AUDIO + if "text" in mime_type or "pdf" in mime_type: + return FileType.DOCUMENT + return FileType.CUSTOM diff --git a/api/dify_graph/file/file_manager.py b/api/graphon/file/file_manager.py similarity index 74% rename from api/dify_graph/file/file_manager.py rename to api/graphon/file/file_manager.py index 8d998054db..d7e4d472e7 100644 --- a/api/dify_graph/file/file_manager.py +++ b/api/graphon/file/file_manager.py @@ -3,16 +3,15 @@ from __future__ import annotations import base64 from collections.abc import Mapping -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, VideoPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from . import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType from .runtime import get_workflow_file_runtime @@ -80,7 +79,7 @@ def download(f: File, /) -> bytes: FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE, ): - return _download_file_content(f.storage_key) + return _download_file_content(f) elif f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: raise ValueError("Missing file remote_url") @@ -90,12 +89,9 @@ def download(f: File, /) -> bytes: raise ValueError(f"unsupported transfer method: {f.transfer_method}") -def _download_file_content(path: str, /) -> bytes: +def _download_file_content(file: File, /) -> bytes: """Download and return a file from storage as bytes.""" - data = get_workflow_file_runtime().storage_load(path, stream=False) - if not isinstance(data, bytes): - raise ValueError(f"file {path} is not a bytes object") - return data + return get_workflow_file_runtime().load_file_bytes(file=file) def _get_encoded_string(f: File, /) -> str: @@ -107,30 +103,20 @@ def _get_encoded_string(f: File, /) -> str: response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) return base64.b64encode(data).decode("utf-8") def _to_url(f: File, /): - if f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - return f.remote_url - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - if f.related_id is None: - raise ValueError("Missing file related_id") - return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) - elif f.transfer_method == FileTransferMethod.TOOL_FILE: - if f.related_id is None or f.extension is None: - raise ValueError("Missing file related_id or extension") - return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension) - else: + url = f.generate_url() + if url is None: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + return url class FileManager: diff --git a/api/graphon/file/helpers.py b/api/graphon/file/helpers.py new file mode 100644 index 0000000000..dade761227 --- /dev/null +++ b/api/graphon/file/helpers.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .runtime import get_workflow_file_runtime + +if TYPE_CHECKING: + from .models import File + + +def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: + return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) + + +def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: + return get_workflow_file_runtime().resolve_upload_file_url( + upload_file_id=upload_file_id, + as_attachment=as_attachment, + for_external=for_external, + ) + + +def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: + return get_workflow_file_runtime().resolve_tool_file_url( + tool_file_id=tool_file_id, + extension=extension, + for_external=for_external, + ) + + +def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="image", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="file", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) diff --git a/api/dify_graph/file/models.py b/api/graphon/file/models.py similarity index 61% rename from api/dify_graph/file/models.py rename to api/graphon/file/models.py index dcba00978e..ccd7584371 100644 --- a/api/dify_graph/file/models.py +++ b/api/graphon/file/models.py @@ -1,17 +1,20 @@ from __future__ import annotations +import base64 +import json from collections.abc import Mapping, Sequence from typing import Any -from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from . import helpers from .constants import FILE_MODEL_IDENTITY from .enums import FileTransferMethod, FileType +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" @@ -44,57 +47,68 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 -class ToolFile(BaseModel): - id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") - user_id: UUID = Field(..., description="ID of the user who owns this file") - tenant_id: UUID = Field(..., description="ID of the tenant/organization") - conversation_id: UUID | None = Field(None, description="ID of the associated conversation") - file_key: str = Field(..., max_length=255, description="Storage key for the file") - mimetype: str = Field(..., max_length=255, description="MIME type of the file") - original_url: str | None = Field( - None, max_length=2048, description="Original URL if file was fetched from external source" - ) - name: str = Field(default="", max_length=255, description="Display name of the file") - size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") +def _parse_reference(reference: str | None) -> tuple[str | None, str | None]: + """Best-effort parser for record references and historical storage-key payloads.""" + if not reference: + return None, None - class Config: - from_attributes = True # Enable ORM mode for SQLAlchemy compatibility - populate_by_name = True + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return reference, None + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return reference, None + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return reference, None + + storage_key = payload.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + return record_id, storage_key class File(BaseModel): + """Graph-owned file reference. + + The graph layer deliberately keeps only the metadata required to route, + serialize, and render files. Application ownership concerns such as + tenant/user/conversation identity stay in the workflow/storage layer. + """ + # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY id: str | None = None # message file id - tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. remote_url: str | None = None # remote url - # If `transfer_method` is `FileTransferMethod.local_file` or - # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. - # - # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: str | None = None + # Opaque workflow-layer reference for files resolved outside ``graphon``. + # New payloads only carry the backing record id; historical payloads may + # still include storage_key and must remain readable. + reference: str | None = None filename: str | None = None extension: str | None = Field(default=None, description="File extension, should contain dot") mime_type: str | None = None size: int = -1 - - # Those properties are private, should not be exposed to the outside. _storage_key: str def __init__( self, *, id: str | None = None, - tenant_id: str, + tenant_id: str | None = None, type: FileType, transfer_method: FileTransferMethod, remote_url: str | None = None, + reference: str | None = None, related_id: str | None = None, filename: str | None = None, extension: str | None = None, @@ -103,18 +117,23 @@ class File(BaseModel): storage_key: str | None = None, dify_model_identity: str | None = FILE_MODEL_IDENTITY, url: str | None = None, - # Legacy compatibility fields - explicitly handle known extra fields + # Legacy compatibility fields - explicitly accept known extra fields tool_file_id: str | None = None, upload_file_id: str | None = None, datasource_file_id: str | None = None, ): + legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id + normalized_reference = reference + if normalized_reference is None and legacy_record_id is not None: + normalized_reference = str(legacy_record_id) + _, parsed_storage_key = _parse_reference(normalized_reference) + super().__init__( id=id, - tenant_id=tenant_id, type=type, transfer_method=transfer_method, remote_url=remote_url, - related_id=related_id, + reference=normalized_reference, filename=filename, extension=extension, mime_type=mime_type, @@ -122,12 +141,15 @@ class File(BaseModel): dify_model_identity=dify_model_identity, url=url, ) - self._storage_key = str(storage_key) + # Accept legacy constructor fields without promoting them back into the graph model. + _ = tenant_id + self._storage_key = storage_key or parsed_storage_key or "" def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { **data, + "related_id": self.related_id, "url": self.generate_url(), } @@ -142,21 +164,7 @@ class File(BaseModel): return text def generate_url(self, for_external: bool = True) -> str | None: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) - elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: - assert self.related_id is not None - assert self.extension is not None - return sign_tool_file( - tool_file_id=self.related_id, - extension=self.extension, - for_external=for_external, - ) - return None + return helpers.resolve_file_url(self, for_external=for_external) def to_plugin_parameter(self) -> dict[str, Any]: return { @@ -178,19 +186,29 @@ class File(BaseModel): if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): raise ValueError("Invalid file url") case FileTransferMethod.LOCAL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.TOOL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.DATASOURCE_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") return self + @property + def related_id(self) -> str | None: + record_id, _ = _parse_reference(self.reference) + return record_id + + @related_id.setter + def related_id(self, value: str | None) -> None: + self.reference = value + @property def storage_key(self) -> str: - return self._storage_key + _, storage_key = _parse_reference(self.reference) + return storage_key or self._storage_key @storage_key.setter def storage_key(self, value: str) -> None: diff --git a/api/graphon/file/protocols.py b/api/graphon/file/protocols.py new file mode 100644 index 0000000000..0acabe35e5 --- /dev/null +++ b/api/graphon/file/protocols.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Literal, Protocol + +if TYPE_CHECKING: + from .models import File + + +class HttpResponseProtocol(Protocol): + """Subset of response behavior needed by workflow file helpers.""" + + @property + def content(self) -> bytes: ... + + def raise_for_status(self) -> object: ... + + +class WorkflowFileRuntimeProtocol(Protocol): + """Runtime dependencies required by ``graphon.file``. + + Implementations are expected to be provided by integration layers (for example, + ``core.app.workflow.file_runtime``) so the workflow package avoids importing + application infrastructure modules directly. + """ + + @property + def multimodal_send_format(self) -> str: ... + + def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... + + def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... + + def load_file_bytes(self, *, file: File) -> bytes: ... + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: ... + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: ... diff --git a/api/dify_graph/file/runtime.py b/api/graphon/file/runtime.py similarity index 63% rename from api/dify_graph/file/runtime.py rename to api/graphon/file/runtime.py index 94253e0255..1c5d1c3ca4 100644 --- a/api/dify_graph/file/runtime.py +++ b/api/graphon/file/runtime.py @@ -1,10 +1,13 @@ from __future__ import annotations from collections.abc import Generator -from typing import NoReturn +from typing import TYPE_CHECKING, Literal, NoReturn from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +if TYPE_CHECKING: + from .models import File + class WorkflowFileRuntimeNotConfiguredError(RuntimeError): """Raised when workflow file runtime dependencies were not configured.""" @@ -16,22 +19,6 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): "workflow file runtime is not configured, call set_workflow_file_runtime(...) first" ) - @property - def files_url(self) -> str: - self._raise() - - @property - def internal_files_url(self) -> str | None: - self._raise() - - @property - def secret_key(self) -> str: - self._raise() - - @property - def files_access_timeout(self) -> int: - self._raise() - @property def multimodal_send_format(self) -> str: self._raise() @@ -42,7 +29,33 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: self._raise() - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + self._raise() + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + self._raise() + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._raise() + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._raise() + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: self._raise() diff --git a/api/dify_graph/file/tool_file_parser.py b/api/graphon/file/tool_file_parser.py similarity index 100% rename from api/dify_graph/file/tool_file_parser.py rename to api/graphon/file/tool_file_parser.py diff --git a/api/dify_graph/graph/__init__.py b/api/graphon/graph/__init__.py similarity index 100% rename from api/dify_graph/graph/__init__.py rename to api/graphon/graph/__init__.py diff --git a/api/dify_graph/graph/edge.py b/api/graphon/graph/edge.py similarity index 91% rename from api/dify_graph/graph/edge.py rename to api/graphon/graph/edge.py index f4f67ea6be..1f8a2884e3 100644 --- a/api/dify_graph/graph/edge.py +++ b/api/graphon/graph/edge.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass, field -from dify_graph.enums import NodeState +from graphon.enums import NodeState @dataclass diff --git a/api/dify_graph/graph/graph.py b/api/graphon/graph/graph.py similarity index 98% rename from api/dify_graph/graph/graph.py rename to api/graphon/graph/graph.py index 85117583e0..0f4cd8925f 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/graphon/graph/graph.py @@ -7,10 +7,9 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.nodes.base.node import Node -from libs.typing import is_str +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState +from graphon.nodes.base.node import Node from .edge import Edge from .validation import get_graph_validator @@ -102,7 +101,7 @@ class Graph: source = edge_config.get("source") target = edge_config.get("target") - if not is_str(source) or not is_str(target): + if not isinstance(source, str) or not isinstance(target, str): continue # Create edge @@ -110,7 +109,7 @@ class Graph: edge_counter += 1 source_handle = edge_config.get("sourceHandle", "source") - if not is_str(source_handle): + if not isinstance(source_handle, str): continue edge = Edge( diff --git a/api/dify_graph/graph/graph_template.py b/api/graphon/graph/graph_template.py similarity index 100% rename from api/dify_graph/graph/graph_template.py rename to api/graphon/graph/graph_template.py diff --git a/api/dify_graph/graph/validation.py b/api/graphon/graph/validation.py similarity index 98% rename from api/dify_graph/graph/validation.py rename to api/graphon/graph/validation.py index 50d1440b04..04b501fd33 100644 --- a/api/dify_graph/graph/validation.py +++ b/api/graphon/graph/validation.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType if TYPE_CHECKING: from .graph import Graph diff --git a/api/dify_graph/graph_engine/__init__.py b/api/graphon/graph_engine/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/__init__.py rename to api/graphon/graph_engine/__init__.py diff --git a/api/dify_graph/graph_engine/_engine_utils.py b/api/graphon/graph_engine/_engine_utils.py similarity index 100% rename from api/dify_graph/graph_engine/_engine_utils.py rename to api/graphon/graph_engine/_engine_utils.py diff --git a/api/dify_graph/graph_engine/command_channels/README.md b/api/graphon/graph_engine/command_channels/README.md similarity index 100% rename from api/dify_graph/graph_engine/command_channels/README.md rename to api/graphon/graph_engine/command_channels/README.md diff --git a/api/dify_graph/graph_engine/command_channels/__init__.py b/api/graphon/graph_engine/command_channels/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/__init__.py rename to api/graphon/graph_engine/command_channels/__init__.py diff --git a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py b/api/graphon/graph_engine/command_channels/in_memory_channel.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/in_memory_channel.py rename to api/graphon/graph_engine/command_channels/in_memory_channel.py diff --git a/api/dify_graph/graph_engine/command_channels/redis_channel.py b/api/graphon/graph_engine/command_channels/redis_channel.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/redis_channel.py rename to api/graphon/graph_engine/command_channels/redis_channel.py diff --git a/api/dify_graph/graph_engine/command_processing/__init__.py b/api/graphon/graph_engine/command_processing/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/command_processing/__init__.py rename to api/graphon/graph_engine/command_processing/__init__.py diff --git a/api/dify_graph/graph_engine/command_processing/command_handlers.py b/api/graphon/graph_engine/command_processing/command_handlers.py similarity index 95% rename from api/dify_graph/graph_engine/command_processing/command_handlers.py rename to api/graphon/graph_engine/command_processing/command_handlers.py index eefd0c366b..ad92fd1abb 100644 --- a/api/dify_graph/graph_engine/command_processing/command_handlers.py +++ b/api/graphon/graph_engine/command_processing/command_handlers.py @@ -3,8 +3,8 @@ from typing import final from typing_extensions import override -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.runtime import VariablePool +from graphon.entities.pause_reason import SchedulingPause +from graphon.runtime import VariablePool from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand diff --git a/api/dify_graph/graph_engine/command_processing/command_processor.py b/api/graphon/graph_engine/command_processing/command_processor.py similarity index 100% rename from api/dify_graph/graph_engine/command_processing/command_processor.py rename to api/graphon/graph_engine/command_processing/command_processor.py diff --git a/api/dify_graph/graph_engine/config.py b/api/graphon/graph_engine/config.py similarity index 100% rename from api/dify_graph/graph_engine/config.py rename to api/graphon/graph_engine/config.py diff --git a/api/dify_graph/graph_engine/domain/__init__.py b/api/graphon/graph_engine/domain/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/domain/__init__.py rename to api/graphon/graph_engine/domain/__init__.py diff --git a/api/dify_graph/graph_engine/domain/graph_execution.py b/api/graphon/graph_engine/domain/graph_execution.py similarity index 97% rename from api/dify_graph/graph_engine/domain/graph_execution.py rename to api/graphon/graph_engine/domain/graph_execution.py index 0ee4a9f9a7..9c0c7d1624 100644 --- a/api/dify_graph/graph_engine/domain/graph_execution.py +++ b/api/graphon/graph_engine/domain/graph_execution.py @@ -8,9 +8,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import NodeState -from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeState +from graphon.runtime.graph_runtime_state import GraphExecutionProtocol from .node_execution import NodeExecution diff --git a/api/dify_graph/graph_engine/domain/node_execution.py b/api/graphon/graph_engine/domain/node_execution.py similarity index 96% rename from api/dify_graph/graph_engine/domain/node_execution.py rename to api/graphon/graph_engine/domain/node_execution.py index ae8f9a5e50..dafd6ccd8a 100644 --- a/api/dify_graph/graph_engine/domain/node_execution.py +++ b/api/graphon/graph_engine/domain/node_execution.py @@ -4,7 +4,7 @@ NodeExecution entity representing a node's execution state. from dataclasses import dataclass -from dify_graph.enums import NodeState +from graphon.enums import NodeState @dataclass diff --git a/api/dify_graph/model_runtime/callbacks/__init__.py b/api/graphon/graph_engine/entities/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/callbacks/__init__.py rename to api/graphon/graph_engine/entities/__init__.py diff --git a/api/dify_graph/graph_engine/entities/commands.py b/api/graphon/graph_engine/entities/commands.py similarity index 97% rename from api/dify_graph/graph_engine/entities/commands.py rename to api/graphon/graph_engine/entities/commands.py index c56845cfc4..25ebc804b6 100644 --- a/api/dify_graph/graph_engine/entities/commands.py +++ b/api/graphon/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.variables.variables import Variable +from graphon.variables.variables import Variable class CommandType(StrEnum): diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/graphon/graph_engine/error_handler.py similarity index 97% rename from api/dify_graph/graph_engine/error_handler.py rename to api/graphon/graph_engine/error_handler.py index e206f21592..43ce8bb502 100644 --- a/api/dify_graph/graph_engine/error_handler.py +++ b/api/graphon/graph_engine/error_handler.py @@ -6,21 +6,21 @@ import logging import time from typing import TYPE_CHECKING, final -from dify_graph.enums import ( +from graphon.enums import ( ErrorStrategy as ErrorStrategyEnum, ) -from dify_graph.enums import ( +from graphon.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.graph import Graph +from graphon.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetryEvent, ) -from dify_graph.node_events import NodeRunResult +from graphon.node_events import NodeRunResult if TYPE_CHECKING: from .domain import GraphExecution diff --git a/api/dify_graph/graph_engine/event_management/__init__.py b/api/graphon/graph_engine/event_management/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/event_management/__init__.py rename to api/graphon/graph_engine/event_management/__init__.py diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/graphon/graph_engine/event_management/event_handlers.py similarity index 93% rename from api/dify_graph/graph_engine/event_management/event_handlers.py rename to api/graphon/graph_engine/event_management/event_handlers.py index 7f5ad40e0e..184148280d 100644 --- a/api/dify_graph/graph_engine/event_management/event_handlers.py +++ b/api/graphon/graph_engine/event_management/event_handlers.py @@ -7,9 +7,9 @@ from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState +from graphon.graph import Graph +from graphon.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunExceptionEvent, @@ -28,9 +28,10 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator @@ -93,6 +94,10 @@ class EventHandler: Args: event: The event to handle """ + if isinstance(event, NodeRunVariableUpdatedEvent): + self._dispatch(event) + return + # Events in loops or iterations are always collected if event.in_loop_id or event.in_iteration_id: self._event_collector.collect(event) @@ -153,6 +158,17 @@ class EventHandler: for stream_event in streaming_events: self._event_collector.collect(stream_event) + @_dispatch.register + def _(self, event: NodeRunVariableUpdatedEvent) -> None: + """ + Apply a node-requested variable mutation before downstream observers run. + + The event is collected like other node events so parent/container engines can + forward the updated payload to outer layers, including persistence listeners. + """ + self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) + self._event_collector.collect(event) + @_dispatch.register def _(self, event: NodeRunSucceededEvent) -> None: """ diff --git a/api/dify_graph/graph_engine/event_management/event_manager.py b/api/graphon/graph_engine/event_management/event_manager.py similarity index 99% rename from api/dify_graph/graph_engine/event_management/event_manager.py rename to api/graphon/graph_engine/event_management/event_manager.py index 616f621c3e..5b2fb365e9 100644 --- a/api/dify_graph/graph_engine/event_management/event_manager.py +++ b/api/graphon/graph_engine/event_management/event_manager.py @@ -9,7 +9,7 @@ from collections.abc import Generator from contextlib import contextmanager from typing import final -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_events import GraphEngineEvent from ..layers.base import GraphEngineLayer diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/graphon/graph_engine/graph_engine.py similarity index 91% rename from api/dify_graph/graph_engine/graph_engine.py rename to api/graphon/graph_engine/graph_engine.py index ea98a46b06..32e0e60502 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/graphon/graph_engine/graph_engine.py @@ -9,14 +9,13 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import TYPE_CHECKING, cast, final -from dify_graph.context import capture_current_context -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import NodeExecutionType +from graphon.graph import Graph +from graphon.graph_events import ( GraphEngineEvent, GraphNodeEventBase, GraphRunAbortedEvent, @@ -26,11 +25,11 @@ from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from dify_graph.runtime.graph_runtime_state import GraphProtocol + from graphon.runtime.graph_runtime_state import GraphProtocol from .command_processing import ( AbortCommandHandler, @@ -50,9 +49,9 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator + from graphon.entities import GraphInitParams + from graphon.graph_engine.domain.graph_execution import GraphExecution + from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator logger = logging.getLogger(__name__) @@ -86,6 +85,7 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._layers: list[GraphEngineLayer] = [] self._child_engine_builder = child_engine_builder if child_engine_builder is not None: self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) @@ -149,21 +149,14 @@ class GraphEngine: update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Worker Pool Setup === - # Capture execution context for worker threads - execution_context = capture_current_context() - # Create worker pool for parallel node execution self._worker_pool = WorkerPool( ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, layers=self._layers, - execution_context=execution_context, + execution_context=self._graph_runtime_state.execution_context, config=self._config, ) @@ -220,23 +213,23 @@ class GraphEngine: self._bind_layer_context(layer) return self + def request_abort(self, reason: str | None = None) -> None: + """Queue an abort command for this engine.""" + self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) + def create_child_engine( self, *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: dict[str, object] | Mapping[str, object], root_node_id: str, - layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: return self._graph_runtime_state.create_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) def run(self) -> Generator[GraphEngineEvent, None, None]: diff --git a/api/dify_graph/graph_engine/graph_state_manager.py b/api/graphon/graph_engine/graph_state_manager.py similarity index 99% rename from api/dify_graph/graph_engine/graph_state_manager.py rename to api/graphon/graph_engine/graph_state_manager.py index 922a968435..ade8e403a8 100644 --- a/api/dify_graph/graph_engine/graph_state_manager.py +++ b/api/graphon/graph_engine/graph_state_manager.py @@ -6,8 +6,8 @@ import threading from collections.abc import Sequence from typing import TypedDict, final -from dify_graph.enums import NodeState -from dify_graph.graph import Edge, Graph +from graphon.enums import NodeState +from graphon.graph import Edge, Graph from .ready_queue import ReadyQueue diff --git a/api/dify_graph/graph_engine/graph_traversal/__init__.py b/api/graphon/graph_engine/graph_traversal/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/graph_traversal/__init__.py rename to api/graphon/graph_engine/graph_traversal/__init__.py diff --git a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py b/api/graphon/graph_engine/graph_traversal/edge_processor.py similarity index 97% rename from api/dify_graph/graph_engine/graph_traversal/edge_processor.py rename to api/graphon/graph_engine/graph_traversal/edge_processor.py index c4625a8ff7..e51eee8a69 100644 --- a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py +++ b/api/graphon/graph_engine/graph_traversal/edge_processor.py @@ -5,9 +5,9 @@ Edge processing logic for graph traversal. from collections.abc import Sequence from typing import TYPE_CHECKING, final -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Edge, Graph -from dify_graph.graph_events import NodeRunStreamChunkEvent +from graphon.enums import NodeExecutionType +from graphon.graph import Edge, Graph +from graphon.graph_events import NodeRunStreamChunkEvent from ..graph_state_manager import GraphStateManager from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py b/api/graphon/graph_engine/graph_traversal/skip_propagator.py similarity index 98% rename from api/dify_graph/graph_engine/graph_traversal/skip_propagator.py rename to api/graphon/graph_engine/graph_traversal/skip_propagator.py index 76445bccd2..bdb83b38ad 100644 --- a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py +++ b/api/graphon/graph_engine/graph_traversal/skip_propagator.py @@ -5,7 +5,7 @@ Skip state propagation through the graph. from collections.abc import Sequence from typing import final -from dify_graph.graph import Edge, Graph +from graphon.graph import Edge, Graph from ..graph_state_manager import GraphStateManager diff --git a/api/dify_graph/graph_engine/layers/README.md b/api/graphon/graph_engine/layers/README.md similarity index 100% rename from api/dify_graph/graph_engine/layers/README.md rename to api/graphon/graph_engine/layers/README.md diff --git a/api/dify_graph/graph_engine/layers/__init__.py b/api/graphon/graph_engine/layers/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/layers/__init__.py rename to api/graphon/graph_engine/layers/__init__.py diff --git a/api/dify_graph/graph_engine/layers/base.py b/api/graphon/graph_engine/layers/base.py similarity index 94% rename from api/dify_graph/graph_engine/layers/base.py rename to api/graphon/graph_engine/layers/base.py index 890336c1ca..605615d347 100644 --- a/api/dify_graph/graph_engine/layers/base.py +++ b/api/graphon/graph_engine/layers/base.py @@ -7,10 +7,10 @@ intercept and respond to GraphEngine events. from abc import ABC, abstractmethod -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ReadOnlyGraphRuntimeState +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.runtime import ReadOnlyGraphRuntimeState class GraphEngineLayerNotInitializedError(Exception): diff --git a/api/dify_graph/graph_engine/layers/debug_logging.py b/api/graphon/graph_engine/layers/debug_logging.py similarity index 99% rename from api/dify_graph/graph_engine/layers/debug_logging.py rename to api/graphon/graph_engine/layers/debug_logging.py index 1af2e2db9e..e6585fb3b9 100644 --- a/api/dify_graph/graph_engine/layers/debug_logging.py +++ b/api/graphon/graph_engine/layers/debug_logging.py @@ -11,7 +11,7 @@ from typing import Any, final from typing_extensions import override -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, diff --git a/api/dify_graph/graph_engine/layers/execution_limits.py b/api/graphon/graph_engine/layers/execution_limits.py similarity index 94% rename from api/dify_graph/graph_engine/layers/execution_limits.py rename to api/graphon/graph_engine/layers/execution_limits.py index 48ba5608d9..2742b3acd3 100644 --- a/api/dify_graph/graph_engine/layers/execution_limits.py +++ b/api/graphon/graph_engine/layers/execution_limits.py @@ -15,13 +15,13 @@ from typing import final from typing_extensions import override -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, NodeRunStartedEvent, ) -from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent class LimitType(StrEnum): diff --git a/api/dify_graph/graph_engine/manager.py b/api/graphon/graph_engine/manager.py similarity index 94% rename from api/dify_graph/graph_engine/manager.py rename to api/graphon/graph_engine/manager.py index 955c149069..c728ff6986 100644 --- a/api/dify_graph/graph_engine/manager.py +++ b/api/graphon/graph_engine/manager.py @@ -10,8 +10,8 @@ import logging from collections.abc import Sequence from typing import final -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from dify_graph.graph_engine.entities.commands import ( +from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol +from graphon.graph_engine.entities.commands import ( AbortCommand, GraphEngineCommand, PauseCommand, diff --git a/api/dify_graph/graph_engine/orchestration/__init__.py b/api/graphon/graph_engine/orchestration/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/orchestration/__init__.py rename to api/graphon/graph_engine/orchestration/__init__.py diff --git a/api/dify_graph/graph_engine/orchestration/dispatcher.py b/api/graphon/graph_engine/orchestration/dispatcher.py similarity index 99% rename from api/dify_graph/graph_engine/orchestration/dispatcher.py rename to api/graphon/graph_engine/orchestration/dispatcher.py index f8aaf20b2f..f75bbee08e 100644 --- a/api/dify_graph/graph_engine/orchestration/dispatcher.py +++ b/api/graphon/graph_engine/orchestration/dispatcher.py @@ -8,7 +8,7 @@ import threading import time from typing import TYPE_CHECKING, final -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, diff --git a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py b/api/graphon/graph_engine/orchestration/execution_coordinator.py similarity index 100% rename from api/dify_graph/graph_engine/orchestration/execution_coordinator.py rename to api/graphon/graph_engine/orchestration/execution_coordinator.py diff --git a/api/dify_graph/graph_engine/protocols/command_channel.py b/api/graphon/graph_engine/protocols/command_channel.py similarity index 100% rename from api/dify_graph/graph_engine/protocols/command_channel.py rename to api/graphon/graph_engine/protocols/command_channel.py diff --git a/api/dify_graph/graph_engine/ready_queue/__init__.py b/api/graphon/graph_engine/ready_queue/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/__init__.py rename to api/graphon/graph_engine/ready_queue/__init__.py diff --git a/api/dify_graph/graph_engine/ready_queue/factory.py b/api/graphon/graph_engine/ready_queue/factory.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/factory.py rename to api/graphon/graph_engine/ready_queue/factory.py diff --git a/api/dify_graph/graph_engine/ready_queue/in_memory.py b/api/graphon/graph_engine/ready_queue/in_memory.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/in_memory.py rename to api/graphon/graph_engine/ready_queue/in_memory.py diff --git a/api/dify_graph/graph_engine/ready_queue/protocol.py b/api/graphon/graph_engine/ready_queue/protocol.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/protocol.py rename to api/graphon/graph_engine/ready_queue/protocol.py diff --git a/api/dify_graph/graph_engine/response_coordinator/__init__.py b/api/graphon/graph_engine/response_coordinator/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/response_coordinator/__init__.py rename to api/graphon/graph_engine/response_coordinator/__init__.py diff --git a/api/dify_graph/graph_engine/response_coordinator/coordinator.py b/api/graphon/graph_engine/response_coordinator/coordinator.py similarity index 98% rename from api/dify_graph/graph_engine/response_coordinator/coordinator.py rename to api/graphon/graph_engine/response_coordinator/coordinator.py index 941a8a496b..a6562f0223 100644 --- a/api/dify_graph/graph_engine/response_coordinator/coordinator.py +++ b/api/graphon/graph_engine/response_coordinator/coordinator.py @@ -14,11 +14,11 @@ from uuid import uuid4 from pydantic import BaseModel, Field -from dify_graph.enums import NodeExecutionType, NodeState -from dify_graph.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from dify_graph.nodes.base.template import TextSegment, VariableSegment -from dify_graph.runtime import VariablePool -from dify_graph.runtime.graph_runtime_state import GraphProtocol +from graphon.enums import NodeExecutionType, NodeState +from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent +from graphon.nodes.base.template import TextSegment, VariableSegment +from graphon.runtime import VariablePool +from graphon.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession diff --git a/api/dify_graph/graph_engine/response_coordinator/path.py b/api/graphon/graph_engine/response_coordinator/path.py similarity index 100% rename from api/dify_graph/graph_engine/response_coordinator/path.py rename to api/graphon/graph_engine/response_coordinator/path.py diff --git a/api/dify_graph/graph_engine/response_coordinator/session.py b/api/graphon/graph_engine/response_coordinator/session.py similarity index 94% rename from api/dify_graph/graph_engine/response_coordinator/session.py rename to api/graphon/graph_engine/response_coordinator/session.py index 11a9f5dac5..cb877f1504 100644 --- a/api/dify_graph/graph_engine/response_coordinator/session.py +++ b/api/graphon/graph_engine/response_coordinator/session.py @@ -10,8 +10,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Protocol, cast -from dify_graph.nodes.base.template import Template -from dify_graph.runtime.graph_runtime_state import NodeProtocol +from graphon.nodes.base.template import Template +from graphon.runtime.graph_runtime_state import NodeProtocol class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): diff --git a/api/dify_graph/graph_engine/worker.py b/api/graphon/graph_engine/worker.py similarity index 92% rename from api/dify_graph/graph_engine/worker.py rename to api/graphon/graph_engine/worker.py index 988c20d72a..a0844ee48e 100644 --- a/api/dify_graph/graph_engine/worker.py +++ b/api/graphon/graph_engine/worker.py @@ -9,19 +9,18 @@ import queue import threading import time from collections.abc import Sequence -from datetime import datetime +from contextlib import AbstractContextManager +from datetime import UTC, datetime from typing import TYPE_CHECKING, final from typing_extensions import override -from dify_graph.context import IExecutionContext -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from libs.datetime_utils import naive_utc_now +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .ready_queue import ReadyQueue @@ -46,7 +45,7 @@ class Worker(threading.Thread): graph: Graph, layers: Sequence[GraphEngineLayer], worker_id: int = 0, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize worker thread. @@ -187,7 +186,7 @@ class Worker(threading.Thread): self, node: Node, error: Exception, *, started_at: datetime | None = None ) -> NodeRunFailedEvent: """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = naive_utc_now() + failure_time = datetime.now(UTC).replace(tzinfo=None) error_message = str(error) return NodeRunFailedEvent( id=node.execution_id, diff --git a/api/dify_graph/graph_engine/worker_management/__init__.py b/api/graphon/graph_engine/worker_management/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/worker_management/__init__.py rename to api/graphon/graph_engine/worker_management/__init__.py diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/graphon/graph_engine/worker_management/worker_pool.py similarity index 97% rename from api/dify_graph/graph_engine/worker_management/worker_pool.py rename to api/graphon/graph_engine/worker_management/worker_pool.py index cc93087783..85cdf1ca21 100644 --- a/api/dify_graph/graph_engine/worker_management/worker_pool.py +++ b/api/graphon/graph_engine/worker_management/worker_pool.py @@ -8,11 +8,11 @@ DynamicScaler, and WorkerFactory into a single class. import logging import queue import threading +from contextlib import AbstractContextManager from typing import final -from dify_graph.context import IExecutionContext -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphNodeEventBase +from graphon.graph import Graph +from graphon.graph_events import GraphNodeEventBase from ..config import GraphEngineConfig from ..layers.base import GraphEngineLayer @@ -38,7 +38,7 @@ class WorkerPool: graph: Graph, layers: list[GraphEngineLayer], config: GraphEngineConfig, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize the simple worker pool. diff --git a/api/dify_graph/graph_events/__init__.py b/api/graphon/graph_events/__init__.py similarity index 96% rename from api/dify_graph/graph_events/__init__.py rename to api/graphon/graph_events/__init__.py index 56ea642092..7cec587a05 100644 --- a/api/dify_graph/graph_events/__init__.py +++ b/api/graphon/graph_events/__init__.py @@ -46,6 +46,7 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, is_node_result_event, ) @@ -78,5 +79,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "NodeRunVariableUpdatedEvent", "is_node_result_event", ] diff --git a/api/dify_graph/graph_events/agent.py b/api/graphon/graph_events/agent.py similarity index 100% rename from api/dify_graph/graph_events/agent.py rename to api/graphon/graph_events/agent.py diff --git a/api/dify_graph/graph_events/base.py b/api/graphon/graph_events/base.py similarity index 88% rename from api/dify_graph/graph_events/base.py rename to api/graphon/graph_events/base.py index 4560cf5085..4ea9787b9a 100644 --- a/api/dify_graph/graph_events/base.py +++ b/api/graphon/graph_events/base.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from dify_graph.enums import NodeType -from dify_graph.node_events import NodeRunResult +from graphon.enums import NodeType +from graphon.node_events import NodeRunResult class GraphEngineEvent(BaseModel): diff --git a/api/dify_graph/graph_events/graph.py b/api/graphon/graph_events/graph.py similarity index 90% rename from api/dify_graph/graph_events/graph.py rename to api/graphon/graph_events/graph.py index f4aaba64d6..3782cb49bc 100644 --- a/api/dify_graph/graph_events/graph.py +++ b/api/graphon/graph_events/graph.py @@ -1,8 +1,8 @@ from pydantic import Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events import BaseGraphEvent +from graphon.entities.pause_reason import PauseReason +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): diff --git a/api/dify_graph/graph_events/human_input.py b/api/graphon/graph_events/human_input.py similarity index 100% rename from api/dify_graph/graph_events/human_input.py rename to api/graphon/graph_events/human_input.py diff --git a/api/dify_graph/graph_events/iteration.py b/api/graphon/graph_events/iteration.py similarity index 100% rename from api/dify_graph/graph_events/iteration.py rename to api/graphon/graph_events/iteration.py diff --git a/api/dify_graph/graph_events/loop.py b/api/graphon/graph_events/loop.py similarity index 100% rename from api/dify_graph/graph_events/loop.py rename to api/graphon/graph_events/loop.py diff --git a/api/dify_graph/graph_events/node.py b/api/graphon/graph_events/node.py similarity index 86% rename from api/dify_graph/graph_events/node.py rename to api/graphon/graph_events/node.py index df19d6c03b..471ae08ee7 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/graphon/graph_events/node.py @@ -1,10 +1,11 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from datetime import datetime +from typing import Any from pydantic import Field -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason +from graphon.variables.variables import Variable from .base import GraphNodeEventBase @@ -30,7 +31,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase): class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") + retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") context: str = Field(..., description="context") @@ -39,6 +40,12 @@ class NodeRunSucceededEvent(GraphNodeEventBase): finished_at: datetime | None = Field(default=None, description="node finish time") +class NodeRunVariableUpdatedEvent(GraphNodeEventBase): + """Request that the engine apply a variable update before downstream observers continue.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class NodeRunFailedEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") diff --git a/api/dify_graph/model_runtime/README.md b/api/graphon/model_runtime/README.md similarity index 100% rename from api/dify_graph/model_runtime/README.md rename to api/graphon/model_runtime/README.md diff --git a/api/dify_graph/model_runtime/README_CN.md b/api/graphon/model_runtime/README_CN.md similarity index 100% rename from api/dify_graph/model_runtime/README_CN.md rename to api/graphon/model_runtime/README_CN.md diff --git a/api/dify_graph/model_runtime/errors/__init__.py b/api/graphon/model_runtime/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/errors/__init__.py rename to api/graphon/model_runtime/__init__.py diff --git a/api/dify_graph/model_runtime/model_providers/__base/__init__.py b/api/graphon/model_runtime/callbacks/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/__init__.py rename to api/graphon/model_runtime/callbacks/__init__.py diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/graphon/model_runtime/callbacks/base_callback.py similarity index 78% rename from api/dify_graph/model_runtime/callbacks/base_callback.py rename to api/graphon/model_runtime/callbacks/base_callback.py index 20faf3d6cd..cd85cf6301 100644 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ b/api/graphon/model_runtime/callbacks/base_callback.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -34,6 +34,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -46,7 +47,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -63,6 +65,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -76,7 +79,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -93,6 +97,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -106,7 +111,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -123,6 +129,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -136,7 +143,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/graphon/model_runtime/callbacks/logging_callback.py similarity index 81% rename from api/dify_graph/model_runtime/callbacks/logging_callback.py rename to api/graphon/model_runtime/callbacks/logging_callback.py index 49b9ab27eb..f96eb446fc 100644 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ b/api/graphon/model_runtime/callbacks/logging_callback.py @@ -1,13 +1,13 @@ import json import logging import sys -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import cast -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -24,6 +24,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -36,7 +37,8 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ self.print_text("\n[on_llm_before_invoke]\n", color="blue") self.print_text(f"Model: {model}\n", color="blue") @@ -53,10 +55,12 @@ class LoggingCallback(Callback): self.print_text(f"\t\t{tool.name}\n", color="blue") self.print_text(f"Stream: {stream}\n", color="blue") - if user: self.print_text(f"User: {user}\n", color="blue") + if invocation_context: + self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: @@ -80,6 +84,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -93,8 +98,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() @@ -110,6 +116,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -123,8 +130,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context self.print_text("\n[on_llm_after_invoke]\n", color="yellow") self.print_text(f"Content: {result.message.content}\n", color="yellow") @@ -151,6 +159,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -164,7 +173,8 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/dify_graph/model_runtime/entities/__init__.py b/api/graphon/model_runtime/entities/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/entities/__init__.py rename to api/graphon/model_runtime/entities/__init__.py diff --git a/api/dify_graph/model_runtime/entities/common_entities.py b/api/graphon/model_runtime/entities/common_entities.py similarity index 100% rename from api/dify_graph/model_runtime/entities/common_entities.py rename to api/graphon/model_runtime/entities/common_entities.py diff --git a/api/dify_graph/model_runtime/entities/defaults.py b/api/graphon/model_runtime/entities/defaults.py similarity index 98% rename from api/dify_graph/model_runtime/entities/defaults.py rename to api/graphon/model_runtime/entities/defaults.py index 53b732e5c6..bcce17c5d5 100644 --- a/api/dify_graph/model_runtime/entities/defaults.py +++ b/api/graphon/model_runtime/entities/defaults.py @@ -1,4 +1,4 @@ -from dify_graph.model_runtime.entities.model_entities import DefaultParameterName +from graphon.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { diff --git a/api/dify_graph/model_runtime/entities/llm_entities.py b/api/graphon/model_runtime/entities/llm_entities.py similarity index 97% rename from api/dify_graph/model_runtime/entities/llm_entities.py rename to api/graphon/model_runtime/entities/llm_entities.py index eec682a2ae..bfc80f21c5 100644 --- a/api/dify_graph/model_runtime/entities/llm_entities.py +++ b/api/graphon/model_runtime/entities/llm_entities.py @@ -7,8 +7,8 @@ from typing import Any, TypedDict, Union from pydantic import BaseModel, Field -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage +from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo class LLMMode(StrEnum): diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/graphon/model_runtime/entities/message_entities.py similarity index 100% rename from api/dify_graph/model_runtime/entities/message_entities.py rename to api/graphon/model_runtime/entities/message_entities.py diff --git a/api/dify_graph/model_runtime/entities/model_entities.py b/api/graphon/model_runtime/entities/model_entities.py similarity index 98% rename from api/dify_graph/model_runtime/entities/model_entities.py rename to api/graphon/model_runtime/entities/model_entities.py index fbcde6740a..5ec4970faf 100644 --- a/api/dify_graph/model_runtime/entities/model_entities.py +++ b/api/graphon/model_runtime/entities/model_entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, model_validator -from dify_graph.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.common_entities import I18nObject class ModelType(StrEnum): diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/graphon/model_runtime/entities/provider_entities.py similarity index 84% rename from api/dify_graph/model_runtime/entities/provider_entities.py rename to api/graphon/model_runtime/entities/provider_entities.py index 97a99ea7ce..8e6c516fb9 100644 --- a/api/dify_graph/model_runtime/entities/provider_entities.py +++ b/api/graphon/model_runtime/entities/provider_entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType class ConfigurateMethod(StrEnum): @@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel): class SimpleProviderEntity(BaseModel): """ - Simple model class for provider. + Simplified provider schema exposed to callers. + + `provider` is the canonical runtime identifier. `provider_name` is an optional + compatibility alias for short-name lookups and is empty when no alias exists. """ provider: str + provider_name: str = "" label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None @@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel): class ProviderEntity(BaseModel): """ - Model class for provider. + Runtime-native provider schema. + + `provider` is the canonical runtime identifier. `provider_name` is a + compatibility alias for callers that still resolve providers by short name and + is empty when no alias exists. """ provider: str + provider_name: str = "" label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None @@ -153,6 +162,7 @@ class ProviderEntity(BaseModel): """ return SimpleProviderEntity( provider=self.provider, + provider_name=self.provider_name, label=self.label, icon_small=self.icon_small, supported_model_types=self.supported_model_types, diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/graphon/model_runtime/entities/rerank_entities.py similarity index 72% rename from api/dify_graph/model_runtime/entities/rerank_entities.py rename to api/graphon/model_runtime/entities/rerank_entities.py index 99709e1bcd..8a0bb5fac2 100644 --- a/api/dify_graph/model_runtime/entities/rerank_entities.py +++ b/api/graphon/model_runtime/entities/rerank_entities.py @@ -1,6 +1,13 @@ +from typing import TypedDict + from pydantic import BaseModel +class MultimodalRerankInput(TypedDict): + content: str + content_type: str + + class RerankDocument(BaseModel): """ Model class for rerank document. diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/graphon/model_runtime/entities/text_embedding_entities.py similarity index 71% rename from api/dify_graph/model_runtime/entities/text_embedding_entities.py rename to api/graphon/model_runtime/entities/text_embedding_entities.py index a0210c169d..08ffd83b5b 100644 --- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py +++ b/api/graphon/model_runtime/entities/text_embedding_entities.py @@ -1,8 +1,16 @@ from decimal import Decimal +from enum import StrEnum, auto from pydantic import BaseModel -from dify_graph.model_runtime.entities.model_entities import ModelUsage +from graphon.model_runtime.entities.model_entities import ModelUsage + + +class EmbeddingInputType(StrEnum): + """Embedding request input variants understood by the model runtime.""" + + DOCUMENT = auto() + QUERY = auto() class EmbeddingUsage(ModelUsage): diff --git a/api/dify_graph/model_runtime/model_providers/__init__.py b/api/graphon/model_runtime/errors/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__init__.py rename to api/graphon/model_runtime/errors/__init__.py diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/graphon/model_runtime/errors/invoke.py similarity index 100% rename from api/dify_graph/model_runtime/errors/invoke.py rename to api/graphon/model_runtime/errors/invoke.py diff --git a/api/dify_graph/model_runtime/errors/validate.py b/api/graphon/model_runtime/errors/validate.py similarity index 100% rename from api/dify_graph/model_runtime/errors/validate.py rename to api/graphon/model_runtime/errors/validate.py diff --git a/api/dify_graph/model_runtime/memory/__init__.py b/api/graphon/model_runtime/memory/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/memory/__init__.py rename to api/graphon/model_runtime/memory/__init__.py diff --git a/api/dify_graph/model_runtime/memory/prompt_message_memory.py b/api/graphon/model_runtime/memory/prompt_message_memory.py similarity index 89% rename from api/dify_graph/model_runtime/memory/prompt_message_memory.py rename to api/graphon/model_runtime/memory/prompt_message_memory.py index a76a7faf71..03e26e9ff5 100644 --- a/api/dify_graph/model_runtime/memory/prompt_message_memory.py +++ b/api/graphon/model_runtime/memory/prompt_message_memory.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Sequence from typing import Protocol -from dify_graph.model_runtime.entities import PromptMessage +from graphon.model_runtime.entities import PromptMessage DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 diff --git a/api/dify_graph/model_runtime/schema_validators/__init__.py b/api/graphon/model_runtime/model_providers/__base/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/schema_validators/__init__.py rename to api/graphon/model_runtime/model_providers/__base/__init__.py diff --git a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py b/api/graphon/model_runtime/model_providers/__base/ai_model.py similarity index 64% rename from api/dify_graph/model_runtime/model_providers/__base/ai_model.py rename to api/graphon/model_runtime/model_providers/__base/ai_model.py index ac7ae9925b..1700ec9740 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py +++ b/api/graphon/model_runtime/model_providers/__base/ai_model.py @@ -1,15 +1,8 @@ import decimal -import hashlib -import logging -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from redis import RedisError - -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, DefaultParameterName, ModelType, @@ -17,7 +10,8 @@ from dify_graph.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from extensions.ext_redis import redis_client - -logger = logging.getLogger(__name__) +from graphon.model_runtime.runtime import ModelRuntime -class AIModel(BaseModel): +class AIModel: """ - Base class for all models. + Runtime-facing base class for all model providers. + + This stays a regular Python class because instances hold live collaborators + such as the provider schema and runtime adapter rather than user input that + benefits from Pydantic validation. Subclasses must pin ``model_type`` via a + class attribute; the base class is not meant to be instantiated directly. """ - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) + model_type: ModelType + provider_schema: ProviderEntity + model_runtime: ModelRuntime + started_at: float - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) + def __init__( + self, + provider_schema: ProviderEntity, + model_runtime: ModelRuntime, + *, + started_at: float = 0, + ) -> None: + if getattr(type(self), "model_type", None) is None: + raise TypeError("AIModel subclasses must define model_type as a class attribute") + + self.model_type = type(self).model_type + self.provider_schema = provider_schema + self.model_runtime = model_runtime + self.started_at = started_at + + @property + def provider(self) -> str: + return self.provider_schema.provider + + @property + def provider_display_name(self) -> str: + return self.provider_schema.label.en_US @property def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. + Map model invoke error to unified error. - :return: Invoke error mapping + The key is the error type thrown to the caller, and the value contains + runtime-facing exception types that should be normalized to it. """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - return { InvokeConnectionError: [InvokeConnectionError], InvokeServerUnavailableError: [InvokeServerUnavailableError], InvokeRateLimitError: [InvokeRateLimitError], InvokeAuthorizationError: [InvokeAuthorizationError], InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], ValueError: [ValueError], } @@ -79,15 +89,18 @@ class AIModel(BaseModel): if invoke_error == InvokeAuthorizationError: return InvokeAuthorizationError( description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." + f"[{self.provider_display_name}] Incorrect model credentials provided, " + "please check and try again." ) ) elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + return InvokeError( + description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" + ) else: return error - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") + return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: """ @@ -144,65 +157,13 @@ class AIModel(BaseModel): :param credentials: model credentials :return: model schema """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, + return self.model_runtime.get_model_schema( + provider=self.provider, + model_type=self.model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/graphon/model_runtime/model_providers/__base/large_language_model.py similarity index 88% rename from api/dify_graph/model_runtime/model_providers/__base/large_language_model.py rename to api/graphon/model_runtime/model_providers/__base/large_language_model.py index bf864ca227..0f909646a1 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ b/api/graphon/model_runtime/model_providers/__base/large_language_model.py @@ -1,27 +1,24 @@ import logging import time import uuid -from collections.abc import Callable, Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterator, Mapping, Sequence from typing import Union -from pydantic import ConfigDict - -from configs import dify_config -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.callbacks.logging_callback import LoggingCallback +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentUnionTypes, PromptMessageTool, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.model_entities import ( ModelType, PriceType, ) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -140,11 +137,9 @@ def _build_llm_result_from_chunks( ) -def _invoke_llm_via_plugin( +def _invoke_llm_via_runtime( *, - tenant_id: str, - user_id: str, - plugin_id: str, + llm_model: "LargeLanguageModel", provider: str, model: str, credentials: dict, @@ -154,25 +149,19 @@ def _invoke_llm_via_plugin( stop: Sequence[str] | None, stream: bool, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_llm( - tenant_id=tenant_id, - user_id=user_id, - plugin_id=plugin_id, + return llm_model.model_runtime.invoke_llm( provider=provider, model=model, credentials=credentials, model_parameters=model_parameters, prompt_messages=list(prompt_messages), tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) -def _normalize_non_stream_plugin_result( +def _normalize_non_stream_runtime_result( model: str, prompt_messages: Sequence[PromptMessage], result: Union[LLMResult, Iterator[LLMResultChunk]], @@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel): model_type: ModelType = ModelType.LLM - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, @@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ @@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if dify_config.DEBUG: + if logger.isEnabledFor(logging.DEBUG): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - result = _invoke_llm_via_plugin( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + result = _invoke_llm_via_runtime( + llm_model=self, + provider=self.provider, model=model, credentials=credentials, model_parameters=model_parameters, @@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel): ) if not stream: - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: @@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) @@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) elif isinstance(result, LLMResult): @@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) # Following https://github.com/langgenius/dify/issues/17799, @@ -342,7 +320,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ @@ -384,7 +362,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -415,7 +393,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :return: """ - if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_llm_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - return 0 + return self.model_runtime.get_llm_num_tokens( + provider=self.provider, + model_type=self.model_type, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + ) def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int @@ -504,7 +474,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -517,7 +487,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -532,7 +502,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -546,7 +516,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -560,7 +530,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ _run_callbacks( callbacks, @@ -575,7 +545,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -589,7 +559,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -603,7 +573,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -619,7 +589,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -633,7 +603,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -647,7 +617,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -663,6 +633,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) diff --git a/api/graphon/model_runtime/model_providers/__base/moderation_model.py b/api/graphon/model_runtime/model_providers/__base/moderation_model.py new file mode 100644 index 0000000000..01f6842998 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/__base/moderation_model.py @@ -0,0 +1,33 @@ +import time + +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel + + +class ModerationModel(AIModel): + """ + Model class for moderation model. + """ + + model_type: ModelType = ModelType.MODERATION + + def invoke(self, model: str, credentials: dict, text: str) -> bool: + """ + Invoke moderation model + + :param model: model name + :param credentials: model credentials + :param text: text to moderate + :return: false if text is safe, true otherwise + """ + self.started_at = time.perf_counter() + + try: + return self.model_runtime.invoke_moderation( + provider=self.provider, + model=model, + credentials=credentials, + text=text, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/graphon/model_runtime/model_providers/__base/rerank_model.py similarity index 61% rename from api/dify_graph/model_runtime/model_providers/__base/rerank_model.py rename to api/graphon/model_runtime/model_providers/__base/rerank_model.py index 5da2b84b95..94b2b5a4fb 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py +++ b/api/graphon/model_runtime/model_providers/__base/rerank_model.py @@ -1,6 +1,6 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.model_providers.__base.ai_model import AIModel class RerankModel(AIModel): @@ -18,7 +18,6 @@ class RerankModel(AIModel): docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -29,18 +28,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, @@ -55,11 +47,10 @@ class RerankModel(AIModel): self, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke multimodal rerank model @@ -69,18 +60,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, diff --git a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py new file mode 100644 index 0000000000..4f5d648639 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py @@ -0,0 +1,31 @@ +from typing import IO + +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel + + +class Speech2TextModel(AIModel): + """ + Model class for speech2text model. + """ + + model_type: ModelType = ModelType.SPEECH2TEXT + + def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: + """ + Invoke speech to text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :return: text for given audio file + """ + try: + return self.model_runtime.invoke_speech_to_text( + provider=self.provider, + model=model, + credentials=credentials, + file=file, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py similarity index 65% rename from api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py rename to api/graphon/model_runtime/model_providers/__base/text_embedding_model.py index 3438da2ada..c8b4a0a6af 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,9 +1,6 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.model_providers.__base.ai_model import AIModel class TextEmbeddingModel(AIModel): @@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel): model_type: ModelType = ModelType.TEXT_EMBEDDING - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, credentials: dict, texts: list[str] | None = None, multimodel_documents: list[dict] | None = None, - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ @@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel): :param credentials: model credentials :param texts: texts to embed :param files: files to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ - from core.plugin.impl.model import PluginModelClient - try: - plugin_model_manager = PluginModelClient() if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_text_embedding( + provider=self.provider, model=model, credentials=credentials, texts=texts, input_type=input_type, ) if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_embedding( + provider=self.provider, model=model, credentials=credentials, documents=multimodel_documents, @@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel): :param texts: texts to embed :return: """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_text_embedding_num_tokens( + provider=self.provider, model=model, credentials=credentials, texts=texts, diff --git a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py rename to api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/graphon/model_runtime/model_providers/__base/tts_model.py similarity index 57% rename from api/dify_graph/model_runtime/model_providers/__base/tts_model.py rename to api/graphon/model_runtime/model_providers/__base/tts_model.py index 0656529f22..6846f3c403 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ b/api/graphon/model_runtime/model_providers/__base/tts_model.py @@ -1,10 +1,8 @@ import logging from collections.abc import Iterable -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -16,38 +14,25 @@ class TTSModel(AIModel): model_type: ModelType = ModelType.TTS - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, - tenant_id: str, credentials: dict, content_text: str, voice: str, - user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param voice: model timbre :param content_text: text content to be translated - :param user: unique user id :return: translated audio file """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_tts( + provider=self.provider, model=model, credentials=credentials, content_text=content_text, @@ -65,14 +50,8 @@ class TTSModel(AIModel): :param credentials: The credentials required to access the TTS model. :return: A list of voices supported by the TTS model. """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_tts_model_voices( + provider=self.provider, model=model, credentials=credentials, language=language, diff --git a/api/dify_graph/model_runtime/utils/__init__.py b/api/graphon/model_runtime/model_providers/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/utils/__init__.py rename to api/graphon/model_runtime/model_providers/__init__.py diff --git a/api/dify_graph/model_runtime/model_providers/_position.yaml b/api/graphon/model_runtime/model_providers/_position.yaml similarity index 100% rename from api/dify_graph/model_runtime/model_providers/_position.yaml rename to api/graphon/model_runtime/model_providers/_position.yaml diff --git a/api/graphon/model_runtime/model_providers/model_provider_factory.py b/api/graphon/model_runtime/model_providers/model_provider_factory.py new file mode 100644 index 0000000000..1ea30c7120 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/model_provider_factory.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator +from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) + + +class ModelProviderFactory: + """Factory for provider schemas and model-type instances backed by a runtime adapter.""" + + def __init__(self, model_runtime: ModelRuntime): + if model_runtime is None: + raise ValueError("model_runtime is required.") + self.model_runtime = model_runtime + + def get_providers(self) -> Sequence[ProviderEntity]: + """ + Get all providers. + """ + return list(self.get_model_providers()) + + def get_model_providers(self) -> Sequence[ProviderEntity]: + """ + Get all model providers exposed by the runtime adapter. + """ + return self.model_runtime.fetch_model_providers() + + def get_provider_schema(self, provider: str) -> ProviderEntity: + """ + Get provider schema. + """ + return self.get_model_provider(provider=provider) + + def get_model_provider(self, provider: str) -> ProviderEntity: + """ + Get provider schema. + """ + provider_entity = self._resolve_provider(provider) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + + return provider_entity + + def provider_credentials_validate(self, *, provider: str, credentials: dict): + """ + Validate provider credentials. + """ + provider_entity = self.get_model_provider(provider=provider) + + provider_credential_schema = provider_entity.provider_credential_schema + if not provider_credential_schema: + raise ValueError(f"Provider {provider} does not have provider_credential_schema") + + validator = ProviderCredentialSchemaValidator(provider_credential_schema) + filtered_credentials = validator.validate_and_filter(credentials) + + self.model_runtime.validate_provider_credentials( + provider=provider_entity.provider, + credentials=filtered_credentials, + ) + + return filtered_credentials + + def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): + """ + Validate model credentials. + """ + provider_entity = self.get_model_provider(provider=provider) + + model_credential_schema = provider_entity.model_credential_schema + if not model_credential_schema: + raise ValueError(f"Provider {provider} does not have model_credential_schema") + + validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) + filtered_credentials = validator.validate_and_filter(credentials) + + self.model_runtime.validate_model_credentials( + provider=provider_entity.provider, + model_type=model_type, + model=model, + credentials=filtered_credentials, + ) + + return filtered_credentials + + def get_model_schema( + self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None + ) -> AIModelEntity | None: + """ + Get model schema. + """ + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_model_schema( + provider=provider_entity.provider, + model_type=model_type, + model=model, + credentials=credentials or {}, + ) + + def get_models( + self, + *, + provider: str | None = None, + model_type: ModelType | None = None, + provider_configs: list[ProviderConfig] | None = None, + ) -> list[SimpleProviderEntity]: + """ + Get all models for given model type. + """ + providers = [] + for provider_entity in self.get_model_providers(): + if provider and not self._matches_provider(provider_entity, provider): + continue + + if model_type and model_type not in provider_entity.supported_model_types: + continue + + simple_provider_schema = provider_entity.to_simple_provider() + if model_type is not None: + simple_provider_schema.models = [ + model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type + ] + providers.append(simple_provider_schema) + + return providers + + def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: + """ + Get model type instance by provider name and model type. + """ + provider_schema = self.get_model_provider(provider) + + if model_type == ModelType.LLM: + return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TEXT_EMBEDDING: + return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.RERANK: + return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.SPEECH2TEXT: + return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.MODERATION: + return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TTS: + return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + + raise ValueError(f"Unsupported model type: {model_type}") + + def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + """ + Get provider icon. + """ + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) + + def _resolve_provider(self, provider: str) -> ProviderEntity | None: + return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) + + @staticmethod + def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: + return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/graphon/model_runtime/runtime.py b/api/graphon/model_runtime/runtime.py new file mode 100644 index 0000000000..79862bab8b --- /dev/null +++ b/api/graphon/model_runtime/runtime.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from collections.abc import Generator, Iterable, Sequence +from typing import IO, Any, Protocol, Union, runtime_checkable + +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult + + +@runtime_checkable +class ModelRuntime(Protocol): + """Port for provider discovery, schema lookup, and model execution. + + `provider` is the model runtime's canonical provider identifier. Adapters may + derive transport-specific details from it, but those details stay outside + this boundary. + """ + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: ... + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: ... + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: ... + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: ... + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: ... + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: ... + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: ... + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: ... diff --git a/api/dify_graph/nodes/answer/__init__.py b/api/graphon/model_runtime/schema_validators/__init__.py similarity index 100% rename from api/dify_graph/nodes/answer/__init__.py rename to api/graphon/model_runtime/schema_validators/__init__.py diff --git a/api/dify_graph/model_runtime/schema_validators/common_validator.py b/api/graphon/model_runtime/schema_validators/common_validator.py similarity index 97% rename from api/dify_graph/model_runtime/schema_validators/common_validator.py rename to api/graphon/model_runtime/schema_validators/common_validator.py index 04cdb8e4f7..984507081b 100644 --- a/api/dify_graph/model_runtime/schema_validators/common_validator.py +++ b/api/graphon/model_runtime/schema_validators/common_validator.py @@ -1,6 +1,6 @@ from typing import Union, cast -from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType +from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType class CommonValidator: diff --git a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py similarity index 78% rename from api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py rename to api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py index a97796e98f..9e4830c1b7 100644 --- a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py @@ -1,6 +1,6 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ModelCredentialSchema +from graphon.model_runtime.schema_validators.common_validator import CommonValidator class ModelCredentialSchemaValidator(CommonValidator): diff --git a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py similarity index 79% rename from api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py rename to api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py index 2fed75a76c..05fd3ce142 100644 --- a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -1,5 +1,5 @@ -from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator +from graphon.model_runtime.entities.provider_entities import ProviderCredentialSchema +from graphon.model_runtime.schema_validators.common_validator import CommonValidator class ProviderCredentialSchemaValidator(CommonValidator): diff --git a/api/dify_graph/nodes/end/__init__.py b/api/graphon/model_runtime/utils/__init__.py similarity index 100% rename from api/dify_graph/nodes/end/__init__.py rename to api/graphon/model_runtime/utils/__init__.py diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/graphon/model_runtime/utils/encoders.py similarity index 83% rename from api/dify_graph/model_runtime/utils/encoders.py rename to api/graphon/model_runtime/utils/encoders.py index c85152463e..13abf74767 100644 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ b/api/graphon/model_runtime/utils/encoders.py @@ -1,7 +1,7 @@ import dataclasses import datetime from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network @@ -99,7 +99,7 @@ def jsonable_encoder( exclude_defaults: bool = False, exclude_none: bool = False, custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, + excluded_key_prefixes: Sequence[str] = (), ) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: @@ -126,7 +126,7 @@ def jsonable_encoder( obj_dict, exclude_none=exclude_none, exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if dataclasses.is_dataclass(obj): # Ensure obj is a dataclass instance, not a dataclass type @@ -139,7 +139,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if isinstance(obj, Enum): return obj.value @@ -152,26 +152,28 @@ def jsonable_encoder( if isinstance(obj, dict): encoded_dict = {} for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value + if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): + continue + if value is None and exclude_none: + continue + + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] @@ -184,7 +186,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) ) return encoded_list @@ -212,5 +214,5 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) diff --git a/api/dify_graph/node_events/__init__.py b/api/graphon/node_events/__init__.py similarity index 95% rename from api/dify_graph/node_events/__init__.py rename to api/graphon/node_events/__init__.py index a9bef8f9a2..a2bbf9f176 100644 --- a/api/dify_graph/node_events/__init__.py +++ b/api/graphon/node_events/__init__.py @@ -21,6 +21,7 @@ from .node import ( RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) __all__ = [ @@ -43,4 +44,5 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "VariableUpdatedEvent", ] diff --git a/api/dify_graph/node_events/agent.py b/api/graphon/node_events/agent.py similarity index 100% rename from api/dify_graph/node_events/agent.py rename to api/graphon/node_events/agent.py diff --git a/api/dify_graph/node_events/base.py b/api/graphon/node_events/base.py similarity index 86% rename from api/dify_graph/node_events/base.py rename to api/graphon/node_events/base.py index 2f6259ae7d..dcd1672428 100644 --- a/api/dify_graph/node_events/base.py +++ b/api/graphon/node_events/base.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage class NodeEventBase(BaseModel): diff --git a/api/dify_graph/node_events/iteration.py b/api/graphon/node_events/iteration.py similarity index 100% rename from api/dify_graph/node_events/iteration.py rename to api/graphon/node_events/iteration.py diff --git a/api/dify_graph/node_events/loop.py b/api/graphon/node_events/loop.py similarity index 100% rename from api/dify_graph/node_events/loop.py rename to api/graphon/node_events/loop.py diff --git a/api/dify_graph/node_events/node.py b/api/graphon/node_events/node.py similarity index 79% rename from api/dify_graph/node_events/node.py rename to api/graphon/node_events/node.py index 2e3973b8fa..17f1494cf2 100644 --- a/api/dify_graph/node_events/node.py +++ b/api/graphon/node_events/node.py @@ -4,10 +4,11 @@ from typing import Any from pydantic import Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.file import File -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from graphon.entities.pause_reason import PauseReason +from graphon.file import File +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult +from graphon.variables.variables import Variable from .base import NodeEventBase @@ -45,6 +46,12 @@ class StreamCompletedEvent(NodeEventBase): node_run_result: NodeRunResult = Field(..., description="run result") +class VariableUpdatedEvent(NodeEventBase): + """Notify the engine that a single variable should be applied to the shared pool.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/graphon/nodes/__init__.py b/api/graphon/nodes/__init__.py new file mode 100644 index 0000000000..2d376d104d --- /dev/null +++ b/api/graphon/nodes/__init__.py @@ -0,0 +1,3 @@ +from graphon.enums import BuiltinNodeTypes + +__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/variable_assigner/__init__.py b/api/graphon/nodes/answer/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/__init__.py rename to api/graphon/nodes/answer/__init__.py diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/graphon/nodes/answer/answer_node.py similarity index 83% rename from api/dify_graph/nodes/answer/answer_node.py rename to api/graphon/nodes/answer/answer_node.py index 4286e1a492..c5261a7939 100644 --- a/api/dify_graph/nodes/answer/answer_node.py +++ b/api/graphon/nodes/answer/answer_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.answer.entities import AnswerNodeData -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.variables import ArrayFileSegment, FileSegment, Segment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.answer.entities import AnswerNodeData +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): diff --git a/api/dify_graph/nodes/answer/entities.py b/api/graphon/nodes/answer/entities.py similarity index 93% rename from api/dify_graph/nodes/answer/entities.py rename to api/graphon/nodes/answer/entities.py index cd82df1ac4..c49f1f3895 100644 --- a/api/dify_graph/nodes/answer/entities.py +++ b/api/graphon/nodes/answer/entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AnswerNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/base/__init__.py b/api/graphon/nodes/base/__init__.py similarity index 100% rename from api/dify_graph/nodes/base/__init__.py rename to api/graphon/nodes/base/__init__.py diff --git a/api/dify_graph/nodes/base/entities.py b/api/graphon/nodes/base/entities.py similarity index 96% rename from api/dify_graph/nodes/base/entities.py rename to api/graphon/nodes/base/entities.py index 4f8b2682e1..94b88c097d 100644 --- a/api/dify_graph/nodes/base/entities.py +++ b/api/graphon/nodes/base/entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, field_validator -from dify_graph.entities.base_node_data import BaseNodeData +from graphon.entities.base_node_data import BaseNodeData class VariableSelector(BaseModel): diff --git a/api/dify_graph/nodes/base/node.py b/api/graphon/nodes/base/node.py similarity index 91% rename from api/dify_graph/nodes/base/node.py rename to api/graphon/nodes/base/node.py index 56b46a5894..613ff4f037 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/graphon/nodes/base/node.py @@ -4,23 +4,23 @@ import logging import operator from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence +from datetime import UTC, datetime from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, @@ -39,8 +39,9 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) -from dify_graph.node_events import ( +from graphon.node_events import ( AgentLogEvent, HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -58,9 +59,9 @@ from dify_graph.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) -from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now +from graphon.runtime import GraphRuntimeState NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -68,23 +69,6 @@ _MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -177,8 +161,9 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under the - # canonical workflow namespaces. + # Only treat nodes from the base graphon package as production + # registrations. Higher-layer packages may still register subclasses, + # but graphon itself should not know their module identities. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -186,7 +171,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): + if module_name.startswith("graphon.nodes."): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -263,16 +248,25 @@ class Node(Generic[NodeDataT]): self._node_id = node_id self._node_execution_id: str = "" - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) self._node_data = self.validate_node_data(config["data"]) self.post_init() @classmethod - def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model.""" - return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model. + + Re-validate from a dumped payload instead of `from_attributes=True` so compatibility + extras stored on `BaseNodeData` survive the handoff to the concrete node data model. + Human Input delivery methods are one such extra field until graphon owns that schema. + """ + if isinstance(node_data, BaseNodeData): + payload = node_data.model_dump(mode="python") + else: + payload = dict(node_data) + return cast(NodeDataT, cls._node_data_type.model_validate(payload)) def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" @@ -299,25 +293,6 @@ class Node(Generic[NodeDataT]): raise ValueError(f"run_context missing required key: {key}") return value - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - @property def execution_id(self) -> str: return self._node_execution_id @@ -364,7 +339,7 @@ class Node(Generic[NodeDataT]): def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) # Create and push start event with required fields start_event = NodeRunStartedEvent( @@ -406,7 +381,7 @@ class Node(Generic[NodeDataT]): error=str(e), error_type="WorkflowNodeError", ) - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) yield NodeRunFailedEvent( id=self.execution_id, node_id=self._node_id, @@ -570,7 +545,7 @@ class Node(Generic[NodeDataT]): return self._node_data def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( @@ -611,7 +586,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -637,6 +612,15 @@ class Node(Generic[NodeDataT]): f"Node {self._node_id} does not support status {event.node_run_result.status}" ) + @_dispatch.register + def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + variable=event.variable, + ) + @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( @@ -793,16 +777,11 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] return NodeRunRetrieverResourceEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, - retriever_resources=retriever_resources, + retriever_resources=event.retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/dify_graph/nodes/base/template.py b/api/graphon/nodes/base/template.py similarity index 98% rename from api/dify_graph/nodes/base/template.py rename to api/graphon/nodes/base/template.py index 5976e808e3..311de4a6ea 100644 --- a/api/dify_graph/nodes/base/template.py +++ b/api/graphon/nodes/base/template.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Union -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.base.variable_template_parser import VariableTemplateParser @dataclass(frozen=True) diff --git a/api/dify_graph/nodes/base/usage_tracking_mixin.py b/api/graphon/nodes/base/usage_tracking_mixin.py similarity index 89% rename from api/dify_graph/nodes/base/usage_tracking_mixin.py rename to api/graphon/nodes/base/usage_tracking_mixin.py index bd49419fd3..955bfe6726 100644 --- a/api/dify_graph/nodes/base/usage_tracking_mixin.py +++ b/api/graphon/nodes/base/usage_tracking_mixin.py @@ -1,5 +1,5 @@ -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState class LLMUsageTrackingMixin: diff --git a/api/dify_graph/nodes/base/variable_template_parser.py b/api/graphon/nodes/base/variable_template_parser.py similarity index 100% rename from api/dify_graph/nodes/base/variable_template_parser.py rename to api/graphon/nodes/base/variable_template_parser.py diff --git a/api/dify_graph/nodes/code/__init__.py b/api/graphon/nodes/code/__init__.py similarity index 100% rename from api/dify_graph/nodes/code/__init__.py rename to api/graphon/nodes/code/__init__.py diff --git a/api/dify_graph/nodes/code/code_node.py b/api/graphon/nodes/code/code_node.py similarity index 97% rename from api/dify_graph/nodes/code/code_node.py rename to api/graphon/nodes/code/code_node.py index 82d5fced62..c2eea0bec1 100644 --- a/api/dify_graph/nodes/code/code_node.py +++ b/api/graphon/nodes/code/code_node.py @@ -3,14 +3,14 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.segments import ArrayFileSegment -from dify_graph.variables.types import SegmentType +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.segments import ArrayFileSegment +from graphon.variables.types import SegmentType from .exc import ( CodeNodeError, @@ -19,8 +19,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class WorkflowCodeExecutor(Protocol): diff --git a/api/dify_graph/nodes/code/entities.py b/api/graphon/nodes/code/entities.py similarity index 85% rename from api/dify_graph/nodes/code/entities.py rename to api/graphon/nodes/code/entities.py index 55b4ee4862..dc89d64495 100644 --- a/api/dify_graph/nodes/code/entities.py +++ b/api/graphon/nodes/code/entities.py @@ -3,10 +3,10 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import VariableSelector +from graphon.variables.types import SegmentType class CodeLanguage(StrEnum): diff --git a/api/dify_graph/nodes/code/exc.py b/api/graphon/nodes/code/exc.py similarity index 100% rename from api/dify_graph/nodes/code/exc.py rename to api/graphon/nodes/code/exc.py diff --git a/api/dify_graph/nodes/code/limits.py b/api/graphon/nodes/code/limits.py similarity index 100% rename from api/dify_graph/nodes/code/limits.py rename to api/graphon/nodes/code/limits.py diff --git a/api/dify_graph/nodes/document_extractor/__init__.py b/api/graphon/nodes/document_extractor/__init__.py similarity index 100% rename from api/dify_graph/nodes/document_extractor/__init__.py rename to api/graphon/nodes/document_extractor/__init__.py diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/graphon/nodes/document_extractor/entities.py similarity index 73% rename from api/dify_graph/nodes/document_extractor/entities.py rename to api/graphon/nodes/document_extractor/entities.py index 1110cc2710..026a0cd224 100644 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ b/api/graphon/nodes/document_extractor/entities.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from dataclasses import dataclass -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class DocumentExtractorNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/document_extractor/exc.py b/api/graphon/nodes/document_extractor/exc.py similarity index 100% rename from api/dify_graph/nodes/document_extractor/exc.py rename to api/graphon/nodes/document_extractor/exc.py diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/graphon/nodes/document_extractor/node.py similarity index 98% rename from api/dify_graph/nodes/document_extractor/node.py rename to api/graphon/nodes/document_extractor/node.py index 27196f1aca..be46481e7d 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/graphon/nodes/document_extractor/node.py @@ -21,14 +21,14 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, file_manager -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment, FileSegment +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, file_manager +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import HttpClientProtocol +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment, FileSegment from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -36,8 +36,8 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DocumentExtractorNode(Node[DocumentExtractorNodeData]): diff --git a/api/dify_graph/nodes/variable_assigner/common/__init__.py b/api/graphon/nodes/end/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/common/__init__.py rename to api/graphon/nodes/end/__init__.py diff --git a/api/dify_graph/nodes/end/end_node.py b/api/graphon/nodes/end/end_node.py similarity index 82% rename from api/dify_graph/nodes/end/end_node.py rename to api/graphon/nodes/end/end_node.py index 1f5cfab22b..11b9e58644 100644 --- a/api/dify_graph/nodes/end/end_node.py +++ b/api/graphon/nodes/end/end_node.py @@ -1,8 +1,8 @@ -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.end.entities import EndNodeData +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template +from graphon.nodes.end.entities import EndNodeData class EndNode(Node[EndNodeData]): diff --git a/api/dify_graph/nodes/end/entities.py b/api/graphon/nodes/end/entities.py similarity index 76% rename from api/dify_graph/nodes/end/entities.py rename to api/graphon/nodes/end/entities.py index be7f0c8de8..839aed7e4b 100644 --- a/api/dify_graph/nodes/end/entities.py +++ b/api/graphon/nodes/end/entities.py @@ -1,8 +1,8 @@ from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import OutputVariableEntity +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import OutputVariableEntity class EndNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/http_request/__init__.py b/api/graphon/nodes/http_request/__init__.py similarity index 100% rename from api/dify_graph/nodes/http_request/__init__.py rename to api/graphon/nodes/http_request/__init__.py diff --git a/api/dify_graph/nodes/http_request/config.py b/api/graphon/nodes/http_request/config.py similarity index 100% rename from api/dify_graph/nodes/http_request/config.py rename to api/graphon/nodes/http_request/config.py diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/graphon/nodes/http_request/entities.py similarity index 98% rename from api/dify_graph/nodes/http_request/entities.py rename to api/graphon/nodes/http_request/entities.py index f594d58ae6..6fa067bdd1 100644 --- a/api/dify_graph/nodes/http_request/entities.py +++ b/api/graphon/nodes/http_request/entities.py @@ -8,8 +8,8 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" diff --git a/api/dify_graph/nodes/http_request/exc.py b/api/graphon/nodes/http_request/exc.py similarity index 100% rename from api/dify_graph/nodes/http_request/exc.py rename to api/graphon/nodes/http_request/exc.py diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/graphon/nodes/http_request/executor.py similarity index 98% rename from api/dify_graph/nodes/http_request/executor.py rename to api/graphon/nodes/http_request/executor.py index 892b0fc688..0c6f4ecd3a 100644 --- a/api/dify_graph/nodes/http_request/executor.py +++ b/api/graphon/nodes/http_request/executor.py @@ -10,9 +10,9 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from dify_graph.file.enums import FileTransferMethod -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, FileSegment +from graphon.file.enums import FileTransferMethod +from graphon.runtime import VariablePool +from graphon.variables.segments import ArrayFileSegment, FileSegment from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( @@ -246,7 +246,7 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None or ( + if file.reference is not None or ( file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None ): file_tuple = ( diff --git a/api/dify_graph/nodes/http_request/node.py b/api/graphon/nodes/http_request/node.py similarity index 89% rename from api/dify_graph/nodes/http_request/node.py rename to api/graphon/nodes/http_request/node.py index 3e5253d809..3d74347a7f 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/graphon/nodes/http_request/node.py @@ -3,17 +3,21 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base import variable_template_parser +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.http_request.executor import Executor +from graphon.nodes.protocols import ( + FileManagerProtocol, + FileReferenceFactoryProtocol, + HttpClientProtocol, + ToolFileManagerProtocol, +) +from graphon.variables.segments import ArrayFileSegment from .config import build_http_request_config, resolve_http_request_config from .entities import ( @@ -28,8 +32,8 @@ from .exc import HttpRequestNodeError, RequestBodyError logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class HttpRequestNode(Node[HttpRequestNodeData]): @@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): http_client: HttpClientProtocol, tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], file_manager: FileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, ) -> None: super().__init__( id=id, @@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory self._file_manager = file_manager + self._file_reference_factory = file_reference_factory @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ - dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -237,20 +242,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - conversation_id=None, file_binary=content, mimetype=mime_type, ) - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, + file = self._file_reference_factory.build_from_mapping( + mapping={ + "tool_file_id": tool_file.id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } ) files.append(file) diff --git a/api/dify_graph/nodes/human_input/__init__.py b/api/graphon/nodes/human_input/__init__.py similarity index 100% rename from api/dify_graph/nodes/human_input/__init__.py rename to api/graphon/nodes/human_input/__init__.py diff --git a/api/graphon/nodes/human_input/entities.py b/api/graphon/nodes/human_input/entities.py new file mode 100644 index 0000000000..aa01bde145 --- /dev/null +++ b/api/graphon/nodes/human_input/entities.py @@ -0,0 +1,208 @@ +"""Human Input node entities. + +The graph package owns the workflow-facing form schema and keeps it transportable +across runtimes. Dify-specific delivery surface and recipient translation stay +outside `graphon`. +""" + +import re +from collections.abc import Mapping, Sequence +from datetime import datetime, timedelta +from typing import Any, Self + +from pydantic import BaseModel, Field, field_validator, model_validator + +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.variables.consts import SELECTORS_LENGTH + +from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit + +_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") + + +class FormInputDefault(BaseModel): + """Default configuration for form inputs.""" + + # NOTE: Ideally, a discriminated union would be used to model + # FormInputDefault. However, the UI requires preserving the previous + # value when switching between `VARIABLE` and `CONSTANT` types. This + # necessitates retaining all fields, making a discriminated union unsuitable. + + type: PlaceholderType + + # The selector of default variable, used when `type` is `VARIABLE`. + selector: Sequence[str] = Field(default_factory=tuple) # + + # The value of the default, used when `type` is `CONSTANT`. + # TODO: How should we express JSON values? + value: str = "" + + @model_validator(mode="after") + def _validate_selector(self) -> Self: + if self.type == PlaceholderType.CONSTANT: + return self + if len(self.selector) < SELECTORS_LENGTH: + raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") + return self + + +class FormInput(BaseModel): + """Form input definition.""" + + type: FormInputType + output_variable_name: str + default: FormInputDefault | None = None + + +_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +class UserAction(BaseModel): + """User action configuration.""" + + # id is the identifier for this action. + # It also serves as the identifiers of output handle. + # + # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) + id: str = Field(max_length=20) + title: str = Field(max_length=20) + button_style: ButtonStyle = ButtonStyle.DEFAULT + + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + if not _IDENTIFIER_PATTERN.match(value): + raise ValueError( + f"'{value}' is not a valid identifier. It must start with a letter or underscore, " + f"and contain only letters, numbers, or underscores." + ) + return value + + +class HumanInputNodeData(BaseNodeData): + """Human Input node data.""" + + type: NodeType = BuiltinNodeTypes.HUMAN_INPUT + form_content: str = "" + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + timeout: int = 36 + timeout_unit: TimeoutUnit = TimeoutUnit.HOUR + + @field_validator("inputs") + @classmethod + def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: + seen_names: set[str] = set() + for form_input in inputs: + name = form_input.output_variable_name + if name in seen_names: + raise ValueError(f"duplicated output_variable_name '{name}' in inputs") + seen_names.add(name) + return inputs + + @field_validator("user_actions") + @classmethod + def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: + seen_ids: set[str] = set() + for action in user_actions: + action_id = action.id + if action_id in seen_ids: + raise ValueError(f"duplicated user action id '{action_id}'") + seen_ids.add(action_id) + return user_actions + + def expiration_time(self, start_time: datetime) -> datetime: + if self.timeout_unit == TimeoutUnit.HOUR: + return start_time + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + return start_time + timedelta(days=self.timeout) + else: + raise AssertionError("unknown timeout unit.") + + def outputs_field_names(self) -> Sequence[str]: + field_names = [] + for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): + field_names.append(match.group("field_name")) + return field_names + + def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: + variable_mappings: dict[str, Sequence[str]] = {} + + def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: + for selector in selectors: + if len(selector) < SELECTORS_LENGTH: + continue + qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" + variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) + + form_template_parser = VariableTemplateParser(template=self.form_content) + _add_variable_selectors( + [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] + ) + + for input in self.inputs: + default_value = input.default + if default_value is None: + continue + if default_value.type == PlaceholderType.CONSTANT: + continue + default_value_key = ".".join(default_value.selector) + qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" + variable_mappings[qualified_variable_mapping_key] = default_value.selector + + return variable_mappings + + def find_action_text(self, action_id: str) -> str: + """ + Resolve action display text by id. + """ + for action in self.user_actions: + if action.id == action_id: + return action.title + return action_id + + +class FormDefinition(BaseModel): + form_content: str + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + rendered_content: str + expiration_time: datetime + + # this is used to store the resolved default values + default_values: dict[str, Any] = Field(default_factory=dict) + + # node_title records the title of the HumanInput node. + node_title: str | None = None + + # display_in_ui controls whether the form should be displayed in UI surfaces. + display_in_ui: bool | None = None + + +class HumanInputSubmissionValidationError(ValueError): + pass + + +def validate_human_input_submission( + *, + inputs: Sequence[FormInput], + user_actions: Sequence[UserAction], + selected_action_id: str, + form_data: Mapping[str, Any], +) -> None: + available_actions = {action.id for action in user_actions} + if selected_action_id not in available_actions: + raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") + + provided_inputs = set(form_data.keys()) + missing_inputs = [ + form_input.output_variable_name + for form_input in inputs + if form_input.output_variable_name not in provided_inputs + ] + + if missing_inputs: + missing_list = ", ".join(missing_inputs) + raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/graphon/nodes/human_input/enums.py similarity index 76% rename from api/dify_graph/nodes/human_input/enums.py rename to api/graphon/nodes/human_input/enums.py index da85728828..3fb0ab4499 100644 --- a/api/dify_graph/nodes/human_input/enums.py +++ b/api/graphon/nodes/human_input/enums.py @@ -25,16 +25,6 @@ class HumanInputFormKind(enum.StrEnum): DELIVERY_TEST = enum.auto() # Form created for delivery tests. -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - class ButtonStyle(enum.StrEnum): """Button styles for user actions.""" @@ -63,10 +53,3 @@ class PlaceholderType(enum.StrEnum): VARIABLE = enum.auto() CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/graphon/nodes/human_input/human_input_node.py similarity index 65% rename from api/dify_graph/nodes/human_input/human_input_node.py rename to api/graphon/nodes/human_input/human_input_node.py index 794e33d92e..fe04022877 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/graphon/nodes/human_input/human_input_node.py @@ -1,39 +1,33 @@ import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, cast -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, NodeRunResult, PauseRequestedEvent, ) -from dify_graph.node_events.base import NodeEventBase -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now +from graphon.node_events.base import NodeEventBase +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType +from .entities import HumanInputNodeData +from .enums import HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: - from dify_graph.entities.graph_init_params import GraphInitParams - from dify_graph.runtime.graph_runtime_state import GraphRuntimeState + from graphon.entities.graph_init_params import GraphInitParams + from graphon.runtime.graph_runtime_state import GraphRuntimeState _SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -56,7 +50,6 @@ class HumanInputNode(Node[HumanInputNodeData]): ) _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository _OUTPUT_FIELD_ACTION_ID = "__action_id" _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" @@ -67,7 +60,8 @@ class HumanInputNode(Node[HumanInputNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, + runtime: HumanInputNodeRuntimeProtocol | None = None, + form_repository: object | None = None, ) -> None: super().__init__( id=id, @@ -75,7 +69,14 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._form_repository = form_repository + resolved_runtime = runtime + if resolved_runtime is None: + raise ValueError("runtime is required") + if form_repository is not None: + with_form_repository = getattr(resolved_runtime, "with_form_repository", None) + if callable(with_form_repository): + resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) + self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime @classmethod def version(cls) -> str: @@ -128,13 +129,7 @@ class HumanInputNode(Node[HumanInputNodeData]): return None - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): required_event = self._human_input_required_event(form_entity) pause_requested_event = PauseRequestedEvent(reason=required_event) return pause_requested_event @@ -157,56 +152,16 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] - - def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: + def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") return HumanInputRequired( form_id=form_entity.id, form_content=form_entity.rendered_content, inputs=node_data.inputs, actions=node_data.user_actions, - display_in_ui=display_in_ui, node_id=self.id, node_title=node_data.title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -217,49 +172,32 @@ class HumanInputNode(Node[HumanInputNodeData]): This method will: 1. Generate a unique form ID 2. Create form content with variable substitution - 3. Create form in database + 3. Persist the form through the configured repository 4. Send form via configured delivery methods 5. Suspend workflow execution 6. Wait for form submission to resume """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() + form = self._runtime.get_form(node_id=self.id) if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=dify_ctx.app_id, - workflow_execution_id=self._workflow_execution_id, + form_entity = self._runtime.create_form( node_id=self.id, - form_config=self._node_data, + node_data=self._node_data, rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - dify_ctx.user_id - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, ) - form_entity = self._form_repository.create_form(params) - # Create human input required event logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + "Human Input node suspended workflow for form. node_id=%s, form_id=%s", self.id, form_entity.id, ) yield self._form_to_pause_event(form_entity) return - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): + if form.status in { + HumanInputFormStatus.TIMEOUT, + HumanInputFormStatus.EXPIRED, + } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): yield HumanInputFormTimeoutEvent( node_title=self._node_data.title, expiration_time=form.expiration_time, diff --git a/api/dify_graph/nodes/if_else/__init__.py b/api/graphon/nodes/if_else/__init__.py similarity index 100% rename from api/dify_graph/nodes/if_else/__init__.py rename to api/graphon/nodes/if_else/__init__.py diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/graphon/nodes/if_else/entities.py similarity index 77% rename from api/dify_graph/nodes/if_else/entities.py rename to api/graphon/nodes/if_else/entities.py index ff09f3c023..d59b782747 100644 --- a/api/dify_graph/nodes/if_else/entities.py +++ b/api/graphon/nodes/if_else/entities.py @@ -2,9 +2,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.utils.condition.entities import Condition +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/graphon/nodes/if_else/if_else_node.py similarity index 87% rename from api/dify_graph/nodes/if_else/if_else_node.py rename to api/graphon/nodes/if_else/if_else_node.py index 7c0370e48c..81e934971a 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/graphon/nodes/if_else/if_else_node.py @@ -3,13 +3,13 @@ from typing import Any, Literal from typing_extensions import deprecated -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.runtime import VariablePool -from dify_graph.utils.condition.entities import Condition -from dify_graph.utils.condition.processor import ConditionProcessor +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.runtime import VariablePool +from graphon.utils.condition.entities import Condition +from graphon.utils.condition.processor import ConditionProcessor class IfElseNode(Node[IfElseNodeData]): @@ -57,8 +57,8 @@ class IfElseNode(Node[IfElseNodeData]): break else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined + # TODO: Remove this once all graph definitions use the `cases` structure. + # Fallback to the legacy node shape when `cases` are not defined. input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, diff --git a/api/dify_graph/nodes/iteration/__init__.py b/api/graphon/nodes/iteration/__init__.py similarity index 100% rename from api/dify_graph/nodes/iteration/__init__.py rename to api/graphon/nodes/iteration/__init__.py diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/graphon/nodes/iteration/entities.py similarity index 89% rename from api/dify_graph/nodes/iteration/entities.py rename to api/graphon/nodes/iteration/entities.py index 58fd112b12..30b6e4bea8 100644 --- a/api/dify_graph/nodes/iteration/entities.py +++ b/api/graphon/nodes/iteration/entities.py @@ -3,9 +3,9 @@ from typing import Any from pydantic import Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base import BaseIterationNodeData, BaseIterationState class ErrorHandleMode(StrEnum): diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/graphon/nodes/iteration/exc.py similarity index 82% rename from api/dify_graph/nodes/iteration/exc.py rename to api/graphon/nodes/iteration/exc.py index d9947e09bc..7b6af61b9d 100644 --- a/api/dify_graph/nodes/iteration/exc.py +++ b/api/graphon/nodes/iteration/exc.py @@ -20,3 +20,7 @@ class IterationGraphNotFoundError(IterationNodeError): class IterationIndexNotFoundError(IterationNodeError): """Raised when the iteration index is not found.""" + + +class ChildGraphAbortedError(IterationNodeError): + """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/graphon/nodes/iteration/iteration_node.py similarity index 74% rename from api/dify_graph/nodes/iteration/iteration_node.py rename to api/graphon/nodes/iteration/iteration_node.py index 033ec8672f..c013739653 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/graphon/nodes/iteration/iteration_node.py @@ -1,27 +1,29 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import suppress from datetime import UTC, datetime +from threading import Lock from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import ( IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -30,16 +32,15 @@ from dify_graph.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.runtime import VariablePool -from dify_graph.variables import IntegerVariable, NoneSegment -from dify_graph.variables.segments import ArrayAnySegment, ArraySegment -from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.runtime import VariablePool +from graphon.variables import IntegerVariable, NoneSegment +from graphon.variables.segments import ArrayAnySegment, ArraySegment from .exc import ( + ChildGraphAbortedError, InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -49,10 +50,10 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.context import IExecutionContext - from dify_graph.graph_engine import GraphEngine + from graphon.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) @@ -93,7 +94,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._validate_start_node() - started_at = naive_utc_now() + started_at = datetime.now(UTC).replace(tzinfo=None) iter_run_map: dict[str, float] = {} outputs: list[object] = [] usage_accumulator = [LLMUsage.empty_usage()] @@ -199,23 +200,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine = self._create_graph_engine(index, item) # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool + try: + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, ) - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) + finally: + self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -233,13 +225,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks + started_child_engines: dict[int, GraphEngine] = {} + started_child_engines_lock = Lock() + merged_usage_indexes: set[int] = set() future_to_index: dict[ Future[ tuple[ float, list[GraphNodeEventBase], object | None, - dict[str, Variable], LLMUsage, ] ], @@ -248,10 +242,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( - self._execute_single_iteration_parallel, + self._execute_tracked_iteration_parallel, index=index, item=item, - execution_context=self._capture_execution_context(), + started_child_engines=started_child_engines, + started_child_engines_lock=started_child_engines_lock, ) future_to_index[future] = index @@ -264,7 +259,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iteration_duration, events, output_value, - conversation_snapshot, iteration_usage, ) = result @@ -279,11 +273,31 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) + merged_usage_indexes.add(index) except Exception as e: + if index not in merged_usage_indexes: + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + if isinstance(e, ChildGraphAbortedError): + self._abort_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, + ) + self._drain_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + usage_accumulator=usage_accumulator, + merged_usage_indexes=merged_usage_indexes, + ) + raise e + # Handle errors based on error_handle_mode match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -301,48 +315,118 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs[:] = [output for output in outputs if output is not None] + @staticmethod + def _merge_graph_engine_usage( + *, + usage_accumulator: list[LLMUsage], + graph_engine: "GraphEngine | None", + ) -> None: + if graph_engine is None: + return + usage_accumulator[0] = IterationNode._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) + + def _abort_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + reason: str, + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + + graph_engine = started_child_engines.get(index) + if graph_engine is not None: + graph_engine.request_abort(reason) + + future.cancel() + + def _drain_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + usage_accumulator: list[LLMUsage], + merged_usage_indexes: set[int], + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + if future.cancelled(): + continue + + with suppress(Exception): + future.result() + + if index in merged_usage_indexes: + continue + + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + + def _execute_tracked_iteration_parallel( + self, + *, + index: int, + item: object, + started_child_engines: dict[int, "GraphEngine"], + started_child_engines_lock: Lock, + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + graph_engine = self._create_graph_engine(index, item) + with started_child_engines_lock: + started_child_engines[index] = graph_engine + + return self._execute_parallel_iteration_with_graph_engine( + index=index, + graph_engine=graph_engine, + ) + def _execute_single_iteration_parallel( self, index: int, item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: """Execute a single iteration in parallel mode and return results.""" - with execution_context: - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] + graph_engine = self._create_graph_engine(index, item) + return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - graph_engine = self._create_graph_engine(index, item) + def _execute_parallel_iteration_with_graph_engine( + self, + *, + index: int, + graph_engine: "GraphEngine", + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + """Execute a prepared child engine in parallel mode and return results.""" + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + events: list[GraphNodeEventBase] = [] + outputs_temp: list[object] = [] - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) + # Collect events instead of yielding them directly + for event in self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs_temp, + graph_engine=graph_engine, + ): + events.append(event) - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # Get the output value from the temporary outputs list + output_value = outputs_temp[0] if outputs_temp else None + iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - return ( - iteration_duration, - events, - output_value, - conversation_snapshot, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() + return ( + iteration_duration, + events, + output_value, + graph_engine.graph_runtime_state.llm_usage, + ) def _handle_iteration_success( self, @@ -516,23 +600,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: - parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) - - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) - def _append_iteration_info_to_event( self, event: GraphNodeEventBase, @@ -575,6 +642,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): else: outputs.append(result.to_object()) return + elif isinstance(event, GraphRunAbortedEvent): + raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) elif isinstance(event, GraphRunFailedEvent): match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -586,8 +655,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return def _create_graph_engine(self, index: int, item: object): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import ChildGraphNotFoundError # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -602,14 +671,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # append iteration variable (item, index) to variable pool variable_pool_copy.add([self._node_id, "index"], index) variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) root_node_id = self.node_data.start_node_id if root_node_id is None: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") @@ -618,9 +679,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, + variable_pool=variable_pool_copy, ) except ChildGraphNotFoundError as exc: raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/iteration/iteration_start_node.py b/api/graphon/nodes/iteration/iteration_start_node.py similarity index 61% rename from api/dify_graph/nodes/iteration/iteration_start_node.py rename to api/graphon/nodes/iteration/iteration_start_node.py index a8ecf3d83b..3a44d3d81d 100644 --- a/api/dify_graph/nodes/iteration/iteration_start_node.py +++ b/api/graphon/nodes/iteration/iteration_start_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import IterationStartNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.iteration.entities import IterationStartNodeData class IterationStartNode(Node[IterationStartNodeData]): diff --git a/api/dify_graph/nodes/list_operator/__init__.py b/api/graphon/nodes/list_operator/__init__.py similarity index 100% rename from api/dify_graph/nodes/list_operator/__init__.py rename to api/graphon/nodes/list_operator/__init__.py diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/graphon/nodes/list_operator/entities.py similarity index 93% rename from api/dify_graph/nodes/list_operator/entities.py rename to api/graphon/nodes/list_operator/entities.py index 41b3a40b78..0db1c75cdd 100644 --- a/api/dify_graph/nodes/list_operator/entities.py +++ b/api/graphon/nodes/list_operator/entities.py @@ -3,8 +3,8 @@ from enum import StrEnum from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class FilterOperator(StrEnum): diff --git a/api/dify_graph/nodes/list_operator/exc.py b/api/graphon/nodes/list_operator/exc.py similarity index 100% rename from api/dify_graph/nodes/list_operator/exc.py rename to api/graphon/nodes/list_operator/exc.py diff --git a/api/dify_graph/nodes/list_operator/node.py b/api/graphon/nodes/list_operator/node.py similarity index 97% rename from api/dify_graph/nodes/list_operator/node.py rename to api/graphon/nodes/list_operator/node.py index dc8b8904f7..dad17a8f4a 100644 --- a/api/dify_graph/nodes/list_operator/node.py +++ b/api/graphon/nodes/list_operator/node.py @@ -1,12 +1,12 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError diff --git a/api/dify_graph/nodes/llm/__init__.py b/api/graphon/nodes/llm/__init__.py similarity index 100% rename from api/dify_graph/nodes/llm/__init__.py rename to api/graphon/nodes/llm/__init__.py diff --git a/api/dify_graph/nodes/llm/entities.py b/api/graphon/nodes/llm/entities.py similarity index 89% rename from api/dify_graph/nodes/llm/entities.py rename to api/graphon/nodes/llm/entities.py index 6ca01a21da..196152548c 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/graphon/nodes/llm/entities.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base.entities import VariableSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode +from graphon.nodes.base.entities import VariableSelector +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig class ModelConfig(BaseModel): diff --git a/api/dify_graph/nodes/llm/exc.py b/api/graphon/nodes/llm/exc.py similarity index 100% rename from api/dify_graph/nodes/llm/exc.py rename to api/graphon/nodes/llm/exc.py diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/graphon/nodes/llm/file_saver.py similarity index 77% rename from api/dify_graph/nodes/llm/file_saver.py rename to api/graphon/nodes/llm/file_saver.py index 50e52a3b6f..0bedb42f3a 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/graphon/nodes/llm/file_saver.py @@ -1,11 +1,9 @@ import mimetypes import typing as tp -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol class LLMFileSaver(tp.Protocol): @@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol): class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str + _tool_file_manager: ToolFileManagerProtocol + _file_reference_factory: FileReferenceFactoryProtocol - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id + def __init__( + self, + *, + tool_file_manager: ToolFileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, + http_client: HttpClientProtocol, + ): + self._tool_file_manager = tool_file_manager + self._file_reference_factory = file_reference_factory self._http_client = http_client - def _get_tool_file_manager(self): - return ToolFileManager() - def save_remote_url(self, url: str, file_type: FileType) -> File: http_response = self._http_client.get(url) http_response.raise_for_status() @@ -83,30 +84,24 @@ class FileSaverImpl(LLMFileSaver): file_type: FileType, extension_override: str | None = None, ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, + tool_file = self._tool_file_manager.create_file_by_raw( file_binary=data, mimetype=mime_type, ) extension_override = _validate_extension_override(extension_override) extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, + return self._file_reference_factory.build_from_mapping( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": tool_file.name, + "extension": extension, + "mime_type": mime_type, + "size": len(data), + "tool_file_id": str(tool_file.id), + "related_id": str(tool_file.id), + "storage_key": tool_file.file_key, + } ) diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/graphon/nodes/llm/llm_utils.py similarity index 78% rename from api/dify_graph/nodes/llm/llm_utils.py rename to api/graphon/nodes/llm/llm_utils.py index 2be391a424..11a1d83a9d 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/graphon/nodes/llm/llm_utils.py @@ -1,31 +1,33 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, cast +import json +import logging +import re +from collections.abc import Mapping, Sequence +from typing import Any -from core.model_manager import ModelInstance -from dify_graph.file import FileType, file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( +from graphon.file import FileType, file_manager +from graphon.file.models import File +from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentType, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, SystemPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment, FileSegment -from dify_graph.variables.segments import ArrayAnySegment, NoneSegment +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.nodes.base.entities import VariableSelector +from graphon.runtime import VariablePool +from graphon.template_rendering import Jinja2TemplateRenderer +from graphon.variables import ArrayFileSegment, FileSegment +from graphon.variables.segments import ArrayAnySegment, NoneSegment from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig from .exc import ( @@ -34,16 +36,20 @@ from .exc import ( NoPromptFoundError, TemplateTypeNotSupportError, ) -from .protocols import TemplateRenderer +from .runtime_protocols import PreparedLLMProtocol + +CONTEXT_PLACEHOLDER = "{{#context#}}" + +logger = logging.getLogger(__name__) + +VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") +MAX_RESOLVED_VALUE_LENGTH = 1024 -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) +def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: + model_schema = model_instance.get_model_schema() if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") return model_schema @@ -114,9 +120,9 @@ def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -125,7 +131,7 @@ def fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] model_schema = fetch_model_schema(model_instance=model_instance) @@ -277,11 +283,11 @@ def fetch_prompt_messages( def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] for message in messages: @@ -300,7 +306,7 @@ def handle_list_messages( ) continue - template = message.text.replace("{#context#}", context) if context else message.text + template = message.text.replace(CONTEXT_PLACEHOLDER, context) segment_group = variable_pool.convert_template(template) file_contents: list[PromptMessageContentUnionTypes] = [] for segment in segment_group.value: @@ -335,7 +341,7 @@ def render_jinja2_message( template: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> str: if not template: return "" @@ -346,16 +352,16 @@ def render_jinja2_message( for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) + return template_renderer.render_template(template, jinja2_inputs) def handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: if template.edition_type == "jinja2": result_text = render_jinja2_message( @@ -365,7 +371,7 @@ def handle_completion_template( template_renderer=template_renderer, ) else: - template_text = template.text.replace("{#context#}", context) if context else template.text + template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) result_text = variable_pool.convert_template(template_text).text return [ combine_message_content_with_role( @@ -391,7 +397,11 @@ def combine_message_content_with_role( raise NotImplementedError(f"Role {role} is not supported") -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: +def calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: rest_tokens = 2000 runtime_model_schema = fetch_model_schema(model_instance=model_instance) runtime_model_parameters = model_instance.parameters @@ -421,7 +431,7 @@ def handle_memory_chat_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> Sequence[PromptMessage]: if not memory or not memory_config: return [] @@ -436,7 +446,7 @@ def handle_memory_completion_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> str: if not memory or not memory_config: return "" @@ -475,3 +485,61 @@ def _append_file_prompts( prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + + +def _coerce_resolved_value(raw: str) -> int | float | bool | str: + """Try to restore the original type from a resolved template string. + + Variable references are always resolved to text, but completion params may + expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to + the ``temperature`` parameter). This helper attempts a JSON parse so that + ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not + valid JSON literals are returned as-is. + """ + stripped = raw.strip() + if not stripped: + return raw + + try: + parsed: object = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return raw + + if isinstance(parsed, (int, float, bool)): + return parsed + return raw + + +def resolve_completion_params_variables( + completion_params: Mapping[str, Any], + variable_pool: VariablePool, +) -> dict[str, Any]: + """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. + + Security notes: + - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to + prevent denial-of-service through excessively large variable payloads. + - This follows the same ``VariablePool.convert_template`` pattern used across + Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream + model plugin receives these values as structured JSON key-value pairs — they + are never concatenated into raw HTTP headers or SQL queries. + - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are + restored to their native type rather than sent as a bare string. + """ + resolved: dict[str, Any] = {} + for key, value in completion_params.items(): + if isinstance(value, str) and VARIABLE_PATTERN.search(value): + segment_group = variable_pool.convert_template(value) + text = segment_group.text + if len(text) > MAX_RESOLVED_VALUE_LENGTH: + logger.warning( + "Resolved value for param '%s' truncated from %d to %d chars", + key, + len(text), + MAX_RESOLVED_VALUE_LENGTH, + ) + text = text[:MAX_RESOLVED_VALUE_LENGTH] + resolved[key] = _coerce_resolved_value(text) + else: + resolved[key] = value + return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/graphon/nodes/llm/node.py similarity index 62% rename from api/dify_graph/nodes/llm/node.py rename to api/graphon/nodes/llm/node.py index 5ed90ed7e3..4de2a95465 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/graphon/nodes/llm/node.py @@ -7,33 +7,24 @@ import logging import re import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, NodeType, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities import ( +from graphon.file import File, FileType, file_manager +from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkWithStructuredOutput, @@ -41,10 +32,17 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, @@ -52,22 +50,26 @@ from dify_graph.node_events import ( StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import ( +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import HttpClientProtocol +from graphon.prompt_entities import CompletionModelPromptTemplate, MemoryConfig +from graphon.runtime import VariablePool +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from graphon.variables import ( ArrayFileSegment, ArraySegment, + FileSegment, NoneSegment, ObjectSegment, StringSegment, ) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile from . import llm_utils from .entities import ( @@ -79,13 +81,16 @@ from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) -from .file_saver import FileSaverImpl, LLMFileSaver +from .file_saver import LLMFileSaver if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -101,11 +106,12 @@ class LLMNode(Node[LLMNodeData]): _file_outputs: list[File] _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance + _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None + _prompt_message_serializer: PromptMessageSerializerProtocol + _jinja2_template_renderer: Jinja2TemplateRenderer | None + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _default_query_selector: tuple[str, ...] | None def __init__( self, @@ -114,13 +120,16 @@ class LLMNode(Node[LLMNodeData]): graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol, + retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, + default_query_selector: Sequence[str] | None = None, ): super().__init__( id=id, @@ -131,20 +140,15 @@ class LLMNode(Node[LLMNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory - self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer + self._retriever_attachment_loader = retriever_attachment_loader + self._jinja2_template_renderer = jinja2_template_renderer + self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None @classmethod def version(cls) -> str: @@ -190,10 +194,11 @@ class LLMNode(Node[LLMNodeData]): generator = self._fetch_context(node_data=self.node_data) context = None context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event + if generator is not None: + for event in generator: + context = event.context + context_files = event.context_files or [] + yield event if context: node_inputs["#context#"] = context @@ -202,6 +207,10 @@ class LLMNode(Node[LLMNodeData]): # fetch model config model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) model_name = model_instance.model_name model_provider = model_instance.provider model_stop = model_instance.stop @@ -211,15 +220,17 @@ class LLMNode(Node[LLMNodeData]): query: str | None = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if ( + not query + and self._default_query_selector + and (query_variable := variable_pool.get(self._default_query_selector)) ): query = query_variable.text prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, - context=context, + context=context or "", memory=memory, model_instance=model_instance, stop=model_stop, @@ -230,7 +241,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, context_files=context_files, - template_renderer=self._template_renderer, + jinja2_template_renderer=self._jinja2_template_renderer, ) # handle invoke result @@ -238,7 +249,6 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -281,7 +291,7 @@ class LLMNode(Node[LLMNodeData]): process_data = { "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=self.node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -349,10 +359,9 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, - user_id: str, structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, @@ -363,35 +372,35 @@ class LLMNode(Node[LLMNodeData]): ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: model_parameters = model_instance.parameters invoke_model_parameters = dict(model_parameters) - - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( structured_output=structured_output or {}, ) request_start_time = time.perf_counter() - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm_with_structured_output( + prompt_messages=prompt_messages, + json_schema=output_schema, + model_parameters=invoke_model_parameters, + stop=stop, + stream=True, + ), ) else: request_start_time = time.perf_counter() - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=invoke_model_parameters, + tools=None, + stop=stop, + stream=True, + ), ) return LLMNode.handle_invoke_result( @@ -400,6 +409,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs=file_outputs, node_id=node_id, node_type=node_type, + model_instance=model_instance, reasoning_format=reasoning_format, request_start_time=request_start_time, ) @@ -412,6 +422,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs: list[File], node_id: str, node_type: NodeType, + model_instance: PreparedLLMProtocol | object, reasoning_format: Literal["separated", "tagged"] = "tagged", request_start_time: float | None = None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: @@ -483,8 +494,14 @@ class LLMNode(Node[LLMNodeData]): usage = result.delta.usage if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") + except Exception as e: + if hasattr(model_instance, "is_structured_output_parse_error") and cast( + PreparedLLMProtocol, model_instance + ).is_structured_output_parse_error(e): + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + if type(e).__name__ == "OutputParserError": + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + raise # Extract reasoning content from tags in the main text full_text = full_text_buffer.getvalue() @@ -687,30 +704,8 @@ class LLMNode(Node[LLMNodeData]): segment_id = retriever_resource.get("segment_id") if not segment_id: continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) + if self._retriever_attachment_loader is not None: + context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) yield RunRetrieverResourceEvent( retriever_resources=original_retriever_resource, context=context_str.strip(), @@ -753,9 +748,9 @@ class LLMNode(Node[LLMNodeData]): *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -764,24 +759,186 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if sys_query: + message = LLMNodeChatModelMessage( + text=sys_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + # For issue #11247 - Check if prompt content is a string or a list + if isinstance(prompt_content, str): + prompt_content = str(prompt_content) + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + + # Add current query to the prompt message + if sys_query: + if isinstance(prompt_content, str): + prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + # The sys_files will be deprecated later + if vision_enabled and sys_files: + file_prompts = [] + for file in sys_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Remove empty messages and filter unsupported content + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message.content, list): + prompt_message_content: list[PromptMessageContentUnionTypes] = [] + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_schema.features + ) + ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue + filtered_prompt_messages.append(prompt_message) + + if len(filtered_prompt_messages) == 0: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) + + return filtered_prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping( @@ -825,9 +982,6 @@ class LLMNode(Node[LLMNodeData]): if node_data.vision.enabled: variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if node_data.prompt_config: enable_jinja = False @@ -877,20 +1031,62 @@ class LLMNode(Node[LLMNodeData]): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages @staticmethod def handle_blocking_result( @@ -1027,5 +1223,150 @@ class LLMNode(Node[LLMNodeData]): return self.node_data.retry_config.retry_enabled @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance + + +def _combine_message_content_with_role( + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole +): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + case _: + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None, +): + if not template: + return "" + + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + if jinja2_template_renderer is None: + raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") + return jinja2_template_renderer.render_template(template, jinja2_inputs) + + +def _calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: + rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters + + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in runtime_model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> Sequence[PromptMessage]: + memory_messages: Sequence[PromptMessage] = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. + + Args: + template: The completion model prompt template + context: Context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + else: + template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER + ) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/graphon/nodes/llm/protocols.py b/api/graphon/nodes/llm/protocols.py new file mode 100644 index 0000000000..65bfd533d1 --- /dev/null +++ b/api/graphon/nodes/llm/protocols.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol + + +class CredentialsProvider(Protocol): + """Port for loading runtime credentials for a provider/model pair.""" + + def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + """Return credentials for the target provider/model or raise a domain error.""" + ... + + +class ModelFactory(Protocol): + """Port for creating prepared graph-facing LLM runtimes for execution.""" + + def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: + """Create a prepared LLM runtime that is ready for graph execution.""" + ... diff --git a/api/graphon/nodes/llm/runtime_protocols.py b/api/graphon/nodes/llm/runtime_protocols.py new file mode 100644 index 0000000000..dbe415d363 --- /dev/null +++ b/api/graphon/nodes/llm/runtime_protocols.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Protocol + +from graphon.file import File +from graphon.model_runtime.entities import LLMMode, PromptMessage +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity + + +class PreparedLLMProtocol(Protocol): + """A graph-facing LLM runtime with provider-specific setup already applied.""" + + @property + def provider(self) -> str: ... + + @property + def model_name(self) -> str: ... + + @property + def parameters(self) -> Mapping[str, Any]: ... + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: ... + + @property + def stop(self) -> Sequence[str] | None: ... + + def get_model_schema(self) -> AIModelEntity: ... + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def is_structured_output_parse_error(self, error: Exception) -> bool: ... + + +class PromptMessageSerializerProtocol(Protocol): + """Port for converting compiled prompt messages into persisted process data.""" + + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: ... + + +class RetrieverAttachmentLoaderProtocol(Protocol): + """Port for resolving retriever segment attachments into graph file references.""" + + def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/api/dify_graph/nodes/loop/__init__.py b/api/graphon/nodes/loop/__init__.py similarity index 100% rename from api/dify_graph/nodes/loop/__init__.py rename to api/graphon/nodes/loop/__init__.py diff --git a/api/dify_graph/nodes/loop/entities.py b/api/graphon/nodes/loop/entities.py similarity index 88% rename from api/dify_graph/nodes/loop/entities.py rename to api/graphon/nodes/loop/entities.py index f0bfad5a0f..e7362769e9 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/graphon/nodes/loop/entities.py @@ -3,11 +3,11 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState -from dify_graph.utils.condition.entities import Condition -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base import BaseLoopNodeData, BaseLoopState +from graphon.utils.condition.entities import Condition +from graphon.variables.types import SegmentType _VALID_VAR_TYPE = frozenset( [ diff --git a/api/dify_graph/nodes/loop/loop_end_node.py b/api/graphon/nodes/loop/loop_end_node.py similarity index 60% rename from api/dify_graph/nodes/loop/loop_end_node.py rename to api/graphon/nodes/loop/loop_end_node.py index 0287708fb3..c0562b59c4 100644 --- a/api/dify_graph/nodes/loop/loop_end_node.py +++ b/api/graphon/nodes/loop/loop_end_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopEndNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopEndNodeData class LoopEndNode(Node[LoopEndNodeData]): diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/graphon/nodes/loop/loop_node.py similarity index 91% rename from api/dify_graph/nodes/loop/loop_node.py rename to api/graphon/nodes/loop/loop_node.py index 3c546ffa23..d574e9f7ae 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/graphon/nodes/loop/loop_node.py @@ -2,23 +2,24 @@ import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, cast -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import ( LoopFailedEvent, LoopNextEvent, LoopStartedEvent, @@ -27,18 +28,17 @@ from dify_graph.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.variables import Segment, SegmentType -from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData +from graphon.utils.condition.processor import ConditionProcessor +from graphon.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable if TYPE_CHECKING: - from dify_graph.graph_engine import GraphEngine + from graphon.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): @@ -91,7 +91,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value - start_at = naive_utc_now() + start_at = datetime.now(UTC).replace(tzinfo=None) condition_processor = ConditionProcessor() loop_duration_map: dict[str, float] = {} @@ -124,10 +124,13 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): self._clear_loop_subgraph_variables(loop_node_ids) graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - loop_start_time = naive_utc_now() - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + loop_start_time = datetime.now(UTC).replace(tzinfo=None) + try: + reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + finally: + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() + loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() # Accumulate outputs from the sub-graph's response nodes for key, value in graph_engine.graph_runtime_state.outputs.items(): @@ -142,9 +145,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # For other outputs, just update self.graph_runtime_state.set_output(key, value) - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -256,6 +256,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): yield event if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True + if isinstance(event, GraphRunAbortedEvent): + raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) @@ -409,8 +411,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return build_segment_with_type(var_type, value) def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -420,16 +421,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): call_depth=self.workflow_call_depth, ) - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, ) diff --git a/api/dify_graph/nodes/loop/loop_start_node.py b/api/graphon/nodes/loop/loop_start_node.py similarity index 60% rename from api/dify_graph/nodes/loop/loop_start_node.py rename to api/graphon/nodes/loop/loop_start_node.py index e171b4df2f..2b17054ae2 100644 --- a/api/dify_graph/nodes/loop/loop_start_node.py +++ b/api/graphon/nodes/loop/loop_start_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopStartNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopStartNodeData class LoopStartNode(Node[LoopStartNodeData]): diff --git a/api/dify_graph/nodes/parameter_extractor/__init__.py b/api/graphon/nodes/parameter_extractor/__init__.py similarity index 100% rename from api/dify_graph/nodes/parameter_extractor/__init__.py rename to api/graphon/nodes/parameter_extractor/__init__.py diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/graphon/nodes/parameter_extractor/entities.py similarity index 93% rename from api/dify_graph/nodes/parameter_extractor/entities.py rename to api/graphon/nodes/parameter_extractor/entities.py index 2fb042c16c..8fda1b9e79 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/graphon/nodes/parameter_extractor/entities.py @@ -7,11 +7,11 @@ from pydantic import ( field_validator, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from graphon.prompt_entities import MemoryConfig +from graphon.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" _OLD_SELECT_TYPE_NAME = "select" diff --git a/api/dify_graph/nodes/parameter_extractor/exc.py b/api/graphon/nodes/parameter_extractor/exc.py similarity index 97% rename from api/dify_graph/nodes/parameter_extractor/exc.py rename to api/graphon/nodes/parameter_extractor/exc.py index c25b809d1c..faa90313c1 100644 --- a/api/dify_graph/nodes/parameter_extractor/exc.py +++ b/api/graphon/nodes/parameter_extractor/exc.py @@ -1,6 +1,6 @@ from typing import Any -from dify_graph.variables.types import SegmentType +from graphon.variables.types import SegmentType class ParameterExtractorNodeError(ValueError): diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py similarity index 83% rename from api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py rename to api/graphon/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..25379e325c 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,21 +5,16 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File +from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -27,17 +22,18 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils -from dify_graph.runtime import VariablePool -from dify_graph.variables.types import ArrayValidation, SegmentType -from factories.variable_factory import build_segment_with_type +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import variable_template_parser +from graphon.nodes.base.node import Node +from graphon.nodes.llm import LLMNode, llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol +from graphon.runtime import VariablePool +from graphon.variables import build_segment_with_type +from graphon.variables.types import ArrayValidation, SegmentType from .entities import ParameterExtractorNodeData from .exc import ( @@ -65,9 +61,8 @@ from .prompts import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState def extract_json(text): @@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" + _model_instance: PreparedLLMProtocol + _prompt_message_serializer: PromptMessageSerializerProtocol _memory: PromptMessageMemory | None def __init__( @@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None = None, + prompt_message_serializer: PromptMessageSerializerProtocol, ) -> None: super().__init__( id=id, @@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory self._model_instance = model_instance + self._prompt_message_serializer = prompt_message_serializer self._memory = memory @classmethod @@ -164,13 +159,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc + if model_schema.model_type != ModelType.LLM: + raise InvalidModelTypeError("Model is not a Large Language Model") memory = self._memory if ( @@ -210,8 +208,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages + "prompts": self._prompt_message_serializer.serialize( + model_mode=node_data.model.mode, + prompt_messages=prompt_messages, ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), @@ -287,18 +286,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: Sequence[str], + stop: Sequence[str] | None, ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, + invoke_result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=dict(model_instance.parameters), + tools=tools or None, + stop=stop, + stream=False, + ), ) # handle invoke result @@ -317,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -329,7 +330,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): content=query, structure=json.dumps(node_data.get_parameter_json_schema()) ) - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -340,15 +340,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -405,7 +401,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -413,9 +409,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate prompt engineering prompt. """ - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: + if data.model.mode == LLMMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( node_data=data, query=query, @@ -425,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - elif model_mode == ModelMode.CHAT: + if data.model.mode == LLMMode.CHAT: return self._generate_prompt_engineering_chat_prompt( node_data=data, query=query, @@ -435,15 +429,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") def _generate_prompt_engineering_completion_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -451,7 +444,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate completion prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -462,27 +454,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), + return self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) - return prompt_messages - def _generate_prompt_engineering_chat_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -490,7 +475,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate chat prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -508,15 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): max_token_limit=rest_token, ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -717,8 +697,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage]: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -727,15 +706,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -744,8 +722,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -754,64 +731,54 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + if node_data.model.mode == LLMMode.COMPLETION: + return LLMNodeCompletionModelPromptTemplate( text=COMPLETION_GENERATE_JSON_PROMPT.format( histories=memory_str, text=input_text, instruction=instruction ) .replace("{γγγ", "") .replace("}γγγ", "") + .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _calculate_rest_token( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=[], + vision_enabled=False, + context=context, ) rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) - + 1000 - ) # add 1000 to ensure tool call messages + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 max_tokens = 0 for parameter_rule in model_schema.parameter_rules: @@ -828,8 +795,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens + def _compile_prompt_messages( + self, + *, + model_instance: PreparedLLMProtocol, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + files: Sequence[File], + vision_enabled: bool, + context: str | None = "", + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> list[PromptMessage]: + prompt_messages, _ = LLMNode.fetch_prompt_messages( + sys_query="", + sys_files=files, + context=context or "", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=model_instance.stop, + memory_config=None, + vision_enabled=vision_enabled, + vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + return list(prompt_messages) + @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod diff --git a/api/dify_graph/nodes/parameter_extractor/prompts.py b/api/graphon/nodes/parameter_extractor/prompts.py similarity index 100% rename from api/dify_graph/nodes/parameter_extractor/prompts.py rename to api/graphon/nodes/parameter_extractor/prompts.py diff --git a/api/dify_graph/nodes/protocols.py b/api/graphon/nodes/protocols.py similarity index 81% rename from api/dify_graph/nodes/protocols.py rename to api/graphon/nodes/protocols.py index 62d3bcdca1..4b050c113c 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/graphon/nodes/protocols.py @@ -1,10 +1,9 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Protocol import httpx -from dify_graph.file import File -from dify_graph.file.models import ToolFile +from graphon.file import File class HttpClientProtocol(Protocol): @@ -35,12 +34,13 @@ class ToolFileManagerProtocol(Protocol): def create_file_by_raw( self, *, - user_id: str, - tenant_id: str, - conversation_id: str | None, file_binary: bytes, mimetype: str, filename: str | None = None, ) -> Any: ... - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... + + +class FileReferenceFactoryProtocol(Protocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/api/dify_graph/nodes/question_classifier/__init__.py b/api/graphon/nodes/question_classifier/__init__.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/__init__.py rename to api/graphon/nodes/question_classifier/__init__.py diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/graphon/nodes/question_classifier/entities.py similarity index 76% rename from api/dify_graph/nodes/question_classifier/entities.py rename to api/graphon/nodes/question_classifier/entities.py index 0c1601d439..8d5f117315 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/graphon/nodes/question_classifier/entities.py @@ -1,9 +1,9 @@ from pydantic import BaseModel, Field -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm import ModelConfig, VisionConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.prompt_entities import MemoryConfig class ClassConfig(BaseModel): diff --git a/api/dify_graph/nodes/question_classifier/exc.py b/api/graphon/nodes/question_classifier/exc.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/exc.py rename to api/graphon/nodes/question_classifier/exc.py diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/graphon/nodes/question_classifier/question_classifier_node.py similarity index 85% rename from api/dify_graph/nodes/question_classifier/question_classifier_node.py rename to api/graphon/nodes/question_classifier/question_classifier_node.py index 59d0a2a4d8..a30ffbb149 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/graphon/nodes/question_classifier/question_classifier_node.py @@ -3,34 +3,32 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm import ( +from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.llm import ( LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.template_rendering import Jinja2TemplateRenderer +from graphon.utils.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData from .exc import InvalidModelTypeError @@ -45,8 +43,14 @@ from .template_prompts import ( ) if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState + + +class _PassthroughPromptMessageSerializer: + def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: + _ = model_mode + return list(prompt_messages) class QuestionClassifierNode(Node[QuestionClassifierNodeData]): @@ -55,11 +59,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _file_outputs: list["File"] _llm_file_saver: LLMFileSaver - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance + _prompt_message_serializer: PromptMessageSerializerProtocol + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _template_renderer: Jinja2TemplateRenderer def __init__( self, @@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, + template_renderer: Jinja2TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol | None = None, ): super().__init__( id=id, @@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() @classmethod def version(cls): @@ -114,6 +111,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variables = {"query": query} # fetch model instance model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) memory = self._memory # fetch instruction node_data.instruction = node_data.instruction or "" @@ -169,7 +170,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, @@ -205,7 +205,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_id = category_id_result process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -247,7 +247,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod @@ -285,7 +285,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) @@ -295,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_template=prompt_template, sys_query="", sys_files=[], - context=context, + context=context or "", memory=None, model_instance=model_instance, stop=model_instance.stop, @@ -334,7 +334,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): memory: PromptMessageMemory | None, max_token_limit: int = 2000, ): - model_mode = ModelMode(node_data.model.mode) + model_mode = LLMMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: @@ -350,7 +350,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: + if model_mode == LLMMode.CHAT: system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) @@ -381,7 +381,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) prompt_messages.append(user_prompt_message_3) return prompt_messages - elif model_mode == ModelMode.COMPLETION: + elif model_mode == LLMMode.COMPLETION: return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, diff --git a/api/dify_graph/nodes/question_classifier/template_prompts.py b/api/graphon/nodes/question_classifier/template_prompts.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/template_prompts.py rename to api/graphon/nodes/question_classifier/template_prompts.py diff --git a/api/graphon/nodes/runtime.py b/api/graphon/nodes/runtime.py new file mode 100644 index 0000000000..650299898c --- /dev/null +++ b/api/graphon/nodes/runtime.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime +from typing import TYPE_CHECKING, Any, Protocol + +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) + +if TYPE_CHECKING: + from graphon.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.enums import HumanInputFormStatus + from graphon.nodes.tool.entities import ToolNodeData + from graphon.runtime import VariablePool + + +class ToolNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter owned by `core.workflow` and consumed by `graphon`. + + The graph package depends only on these DTOs and lets the workflow layer + translate between graph-owned abstractions and `core.tools` internals. + """ + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool: VariablePool | None, + ) -> ToolRuntimeHandle: ... + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: ... + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: ... + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: ... + + def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... + + +class HumanInputNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter for human-input runtime persistence and delivery.""" + + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: ... + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: ... + + +class HumanInputFormStateProtocol(Protocol): + @property + def id(self) -> str: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... diff --git a/api/dify_graph/nodes/start/__init__.py b/api/graphon/nodes/start/__init__.py similarity index 100% rename from api/dify_graph/nodes/start/__init__.py rename to api/graphon/nodes/start/__init__.py diff --git a/api/dify_graph/nodes/start/entities.py b/api/graphon/nodes/start/entities.py similarity index 58% rename from api/dify_graph/nodes/start/entities.py rename to api/graphon/nodes/start/entities.py index 92ebd1a2ec..7df62e1b2b 100644 --- a/api/dify_graph/nodes/start/entities.py +++ b/api/graphon/nodes/start/entities.py @@ -2,9 +2,9 @@ from collections.abc import Sequence from pydantic import Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.input_entities import VariableEntity +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.variables.input_entities import VariableEntity class StartNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/start/start_node.py b/api/graphon/nodes/start/start_node.py similarity index 69% rename from api/dify_graph/nodes/start/start_node.py rename to api/graphon/nodes/start/start_node.py index 5e6055ea34..cb3f4c1e7d 100644 --- a/api/dify_graph/nodes/start/start_node.py +++ b/api/graphon/nodes/start/start_node.py @@ -2,12 +2,11 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.variables.input_entities import VariableEntityType +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.start.entities import StartNodeData +from graphon.variables.input_entities import VariableEntityType class StartNode(Node[StartNodeData]): @@ -19,15 +18,10 @@ class StartNode(Node[StartNodeData]): return "1" def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - outputs = dict(node_inputs) + outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) + outputs.update(node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/dify_graph/nodes/template_transform/__init__.py b/api/graphon/nodes/template_transform/__init__.py similarity index 100% rename from api/dify_graph/nodes/template_transform/__init__.py rename to api/graphon/nodes/template_transform/__init__.py diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/graphon/nodes/template_transform/entities.py similarity index 54% rename from api/dify_graph/nodes/template_transform/entities.py rename to api/graphon/nodes/template_transform/entities.py index ac29239958..a27a57f34f 100644 --- a/api/dify_graph/nodes/template_transform/entities.py +++ b/api/graphon/nodes/template_transform/entities.py @@ -1,6 +1,6 @@ -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import VariableSelector class TemplateTransformNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/graphon/nodes/template_transform/template_transform_node.py similarity index 57% rename from api/dify_graph/nodes/template_transform/template_transform_node.py rename to api/graphon/nodes/template_transform/template_transform_node.py index dc6fce2b0a..4206fb0c1a 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/graphon/nodes/template_transform/template_transform_node.py @@ -1,26 +1,27 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.template_rendering import ( Jinja2TemplateRenderer, TemplateRenderError, ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer + _jinja2_template_renderer: Jinja2TemplateRenderer _max_output_length: int def __init__( @@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer, + jinja2_template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer + self._jinja2_template_renderer = jinja2_template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") @@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) + rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData | Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } + _ = graph_config + raw_variables = ( + node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) + ) + variable_mapping: dict[str, Sequence[str]] = {} + for variable_selector in raw_variables: + if isinstance(variable_selector, VariableSelector): + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector + continue + + if not isinstance(variable_selector, Mapping): + continue + + variable = variable_selector.get("variable") + value_selector = variable_selector.get("value_selector") + if ( + isinstance(variable, str) + and isinstance(value_selector, Sequence) + and all(isinstance(selector_part, str) for selector_part in value_selector) + ): + variable_mapping[node_id + "." + variable] = list(value_selector) + + return variable_mapping diff --git a/api/dify_graph/nodes/tool/__init__.py b/api/graphon/nodes/tool/__init__.py similarity index 100% rename from api/dify_graph/nodes/tool/__init__.py rename to api/graphon/nodes/tool/__init__.py diff --git a/api/dify_graph/nodes/tool/entities.py b/api/graphon/nodes/tool/entities.py similarity index 88% rename from api/dify_graph/nodes/tool/entities.py rename to api/graphon/nodes/tool/entities.py index b041ee66fd..54e6048033 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/graphon/nodes/tool/entities.py @@ -1,11 +1,25 @@ +from enum import StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType + + +class ToolProviderType(StrEnum): + """ + Graph-owned enum for persisted tool provider kinds. + """ + + PLUGIN = auto() + BUILT_IN = "builtin" + WORKFLOW = auto() + API = auto() + APP = auto() + DATASET_RETRIEVAL = "dataset-retrieval" + MCP = auto() class ToolEntity(BaseModel): diff --git a/api/dify_graph/nodes/tool/exc.py b/api/graphon/nodes/tool/exc.py similarity index 53% rename from api/dify_graph/nodes/tool/exc.py rename to api/graphon/nodes/tool/exc.py index 7212e8bfc0..1a309e1084 100644 --- a/api/dify_graph/nodes/tool/exc.py +++ b/api/graphon/nodes/tool/exc.py @@ -4,6 +4,18 @@ class ToolNodeError(ValueError): pass +class ToolRuntimeResolutionError(ToolNodeError): + """Raised when the workflow layer cannot construct a tool runtime.""" + + pass + + +class ToolRuntimeInvocationError(ToolNodeError): + """Raised when the workflow layer fails while invoking a tool runtime.""" + + pass + + class ToolParameterError(ToolNodeError): """Exception raised for errors in tool parameters.""" diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/graphon/nodes/tool/tool_node.py similarity index 60% rename from api/dify_graph/nodes/tool/tool_node.py rename to api/graphon/nodes/tool/tool_node.py index 598f0da92e..57ab8ce5d6 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/graphon/nodes/tool/tool_node.py @@ -1,29 +1,25 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.runtime import ToolNodeRuntimeProtocol +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from graphon.variables.segments import ArrayFileSegment from .entities import ToolNodeData from .exc import ( @@ -33,8 +29,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool class ToolNode(Node[ToolNodeData]): @@ -52,6 +48,7 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state: "GraphRuntimeState", *, tool_file_manager_factory: ToolFileManagerProtocol, + runtime: ToolNodeRuntimeProtocol | None = None, ): super().__init__( id=id, @@ -60,6 +57,9 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state=graph_runtime_state, ) self._tool_file_manager_factory = tool_file_manager_factory + if runtime is None: + raise ValueError("runtime is required") + self._runtime = runtime @classmethod def version(cls) -> str: @@ -73,10 +73,6 @@ class ToolNode(Node[ToolNodeData]): """ Run the tool node """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -86,8 +82,6 @@ class ToolNode(Node[ToolNodeData]): # get tool runtime try: - from core.tools.tool_manager import ToolManager - # This is an issue that caused problems before. # Logically, we shouldn't use the node_data.version field for judgment # But for backward compatibility with historical data @@ -95,13 +89,10 @@ class ToolNode(Node[ToolNodeData]): variable_pool: VariablePool | None = None if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, + tool_runtime = self._runtime.get_runtime( + node_id=self._node_id, + node_data=self.node_data, + variable_pool=variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -116,7 +107,7 @@ class ToolNode(Node[ToolNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -128,18 +119,12 @@ class ToolNode(Node[ToolNodeData]): node_data=self.node_data, for_log=True, ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, + message_stream = self._runtime.invoke( + tool_runtime=tool_runtime, tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + provider_name=self.node_data.provider_name, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -159,38 +144,16 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) - except ToolInvokeError as e: + except ToolNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", + error=str(e), error_type=type(e).__name__, ) ) @@ -198,7 +161,7 @@ class ToolNode(Node[ToolNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + tool_parameters: Sequence[ToolRuntimeParameter], variable_pool: "VariablePool", node_data: ToolNodeData, for_log: bool = False, @@ -207,7 +170,7 @@ class ToolNode(Node[ToolNodeData]): Generate parameters based on the given tool parameters, variable pool, and node data. Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. variable_pool (VariablePool): The variable pool containing the variables. node_data (ToolNodeData): The data associated with the tool node. @@ -240,107 +203,89 @@ class ToolNode(Node[ToolNodeData]): return result - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[ToolRuntimeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, node_id: str, - tool_runtime: Tool, + tool_runtime: ToolRuntimeHandle, + **_: Any, ) -> Generator[NodeEventBase, None, LLMUsage]: """ - Convert ToolInvokeMessages into tuple[plain_text, files] + Convert graph-owned tool runtime messages into node outputs. """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - text = "" files: list[File] = [] json: list[dict | list] = [] variables: dict[str, Any] = {} - for message in message_stream: + for message in messages: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + ToolRuntimeMessage.MessageType.IMAGE_LINK, + ToolRuntimeMessage.MessageType.BINARY_LINK, + ToolRuntimeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool message is missing tool_file_id metadata") - tool_file_id = str(url).split("/")[-1].split(".")[0] - - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not found") + if tool_file.mime_type is None: + raise ToolFileError(f"tool file {tool_file_id} is missing mime type") - mapping = { + file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mime_type), "transfer_method": transfer_method, "url": url, } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) + file = self._runtime.build_file_reference(mapping=file_mapping) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == ToolRuntimeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool blob message is missing tool_file_id metadata") + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not exists") - mapping = { + blob_file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, "transfer_method": FileTransferMethod.TOOL_FILE, } - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + files.append(self._runtime.build_file_reference(mapping=blob_file_mapping)) + elif message.type == ToolRuntimeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) text += message.message.text yield StreamChunkEvent( selector=[node_id, "text"], chunk=message.message.text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == ToolRuntimeMessage.MessageType.JSON: + assert isinstance(message.message, ToolRuntimeMessage.JsonMessage) # JSON message handling for tool node if message.message.json_object: json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == ToolRuntimeMessage.MessageType.LINK: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) # Check if this LINK message is a file link file_obj = (message.meta or {}).get("file") @@ -356,8 +301,8 @@ class ToolNode(Node[ToolNodeData]): chunk=stream_text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == ToolRuntimeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolRuntimeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -374,7 +319,7 @@ class ToolNode(Node[ToolNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == ToolRuntimeMessage.MessageType.FILE: assert message.meta is not None assert isinstance(message.meta, dict) # Validate that meta contains a 'file' key @@ -385,38 +330,16 @@ class ToolNode(Node[ToolNodeData]): if not isinstance(message.meta["file"], File): raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == ToolRuntimeMessage.MessageType.LOG: + assert isinstance(message.message, ToolRuntimeMessage.LogMessage) if message.message.metadata: icon = tool_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - + icon, icon_dark = self._runtime.resolve_provider_icons( + provider_name=dict_metadata["provider"], + default_icon=icon, + ) dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata @@ -446,7 +369,7 @@ class ToolNode(Node[ToolNodeData]): is_final=True, ) - usage = self._extract_tool_usage(tool_runtime) + usage = self._runtime.get_usage(tool_runtime=tool_runtime) metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, @@ -468,21 +391,6 @@ class ToolNode(Node[ToolNodeData]): return usage - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. - if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/graphon/nodes/tool_runtime_entities.py b/api/graphon/nodes/tool_runtime_entities.py new file mode 100644 index 0000000000..5bb0c16573 --- /dev/null +++ b/api/graphon/nodes/tool_runtime_entities.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum, auto +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _ToolRuntimeModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeHandle: + """Opaque graph-owned handle for a workflow-layer tool runtime. + + Workflow-specific execution context must stay behind `raw` so the graph + contract does not absorb application-owned concepts. + """ + + raw: object + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeParameter: + """Graph-owned parameter shape used by tool nodes.""" + + name: str + required: bool = False + + +class ToolRuntimeMessage(_ToolRuntimeModel): + """Graph-owned tool invocation message DTO.""" + + class TextMessage(_ToolRuntimeModel): + text: str + + class JsonMessage(_ToolRuntimeModel): + json_object: dict[str, Any] | list[Any] + suppress_output: bool = Field(default=False) + + class BlobMessage(_ToolRuntimeModel): + blob: bytes + + class BlobChunkMessage(_ToolRuntimeModel): + id: str + sequence: int + total_length: int + blob: bytes + end: bool + + class FileMessage(_ToolRuntimeModel): + file_marker: str = Field(default="file_marker") + + class VariableMessage(_ToolRuntimeModel): + variable_name: str + variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None + stream: bool = Field(default=False) + + class LogMessage(_ToolRuntimeModel): + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() + + id: str + label: str + parent_id: str | None = None + error: str | None = None + status: LogStatus + data: dict[str, Any] + metadata: dict[str, Any] = Field(default_factory=dict) + + class RetrieverResourceMessage(_ToolRuntimeModel): + retriever_resources: list[dict[str, Any]] + context: str + + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() + + type: MessageType = MessageType.TEXT + message: ( + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage + ) + meta: dict[str, Any] | None = None diff --git a/api/dify_graph/nodes/variable_aggregator/__init__.py b/api/graphon/nodes/variable_aggregator/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_aggregator/__init__.py rename to api/graphon/nodes/variable_aggregator/__init__.py diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/graphon/nodes/variable_aggregator/entities.py similarity index 77% rename from api/dify_graph/nodes/variable_aggregator/entities.py rename to api/graphon/nodes/variable_aggregator/entities.py index 4779ebd9a9..136fd28f8c 100644 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ b/api/graphon/nodes/variable_aggregator/entities.py @@ -1,8 +1,8 @@ from pydantic import BaseModel -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.variables.types import SegmentType class AdvancedSettings(BaseModel): diff --git a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py similarity index 81% rename from api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py rename to api/graphon/nodes/variable_aggregator/variable_aggregator_node.py index 7d26de6232..71b221e196 100644 --- a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from dify_graph.variables.segments import Segment +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.variable_aggregator.entities import VariableAggregatorNodeData +from graphon.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): diff --git a/api/dify_graph/utils/__init__.py b/api/graphon/nodes/variable_assigner/__init__.py similarity index 100% rename from api/dify_graph/utils/__init__.py rename to api/graphon/nodes/variable_assigner/__init__.py diff --git a/api/dify_graph/utils/condition/__init__.py b/api/graphon/nodes/variable_assigner/common/__init__.py similarity index 100% rename from api/dify_graph/utils/condition/__init__.py rename to api/graphon/nodes/variable_assigner/common/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/common/exc.py b/api/graphon/nodes/variable_assigner/common/exc.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/common/exc.py rename to api/graphon/nodes/variable_assigner/common/exc.py diff --git a/api/dify_graph/nodes/variable_assigner/common/helpers.py b/api/graphon/nodes/variable_assigner/common/helpers.py similarity index 91% rename from api/dify_graph/nodes/variable_assigner/common/helpers.py rename to api/graphon/nodes/variable_assigner/common/helpers.py index f0b22904a9..4c30e009f2 100644 --- a/api/dify_graph/nodes/variable_assigner/common/helpers.py +++ b/api/graphon/nodes/variable_assigner/common/helpers.py @@ -3,9 +3,9 @@ from typing import Any, TypeVar from pydantic import BaseModel -from dify_graph.variables import Segment -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.types import SegmentType +from graphon.variables import Segment +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables # to minimize risk of collision with user-defined variable names. diff --git a/api/dify_graph/nodes/variable_assigner/v1/__init__.py b/api/graphon/nodes/variable_assigner/v1/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v1/__init__.py rename to api/graphon/nodes/variable_assigner/v1/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/graphon/nodes/variable_assigner/v1/node.py similarity index 64% rename from api/dify_graph/nodes/variable_assigner/v1/node.py rename to api/graphon/nodes/variable_assigner/v1/node.py index f9b261b191..19ded5f123 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/graphon/nodes/variable_assigner/v1/node.py @@ -1,20 +1,19 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.variables import SegmentType, Variable, VariableBase from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: - from dify_graph.runtime import GraphRuntimeState + from graphon.runtime import GraphRuntimeState class VariableAssignerNode(Node[VariableAssignerData]): @@ -56,18 +55,16 @@ class VariableAssignerNode(Node[VariableAssignerData]): node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" mapping[key] = node_data.input_variable_selector return mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) @@ -92,18 +89,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): income_value = SegmentType.get_zero_value(original_variable.value_type) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, + yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "value": income_value.to_object(), + }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, + ) ) diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/graphon/nodes/variable_assigner/v1/node_data.py similarity index 76% rename from api/dify_graph/nodes/variable_assigner/v1/node_data.py rename to api/graphon/nodes/variable_assigner/v1/node_data.py index 57acb29535..4f630bc76c 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ b/api/graphon/nodes/variable_assigner/v1/node_data.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from enum import StrEnum -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class WriteMode(StrEnum): diff --git a/api/dify_graph/nodes/variable_assigner/v2/__init__.py b/api/graphon/nodes/variable_assigner/v2/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v2/__init__.py rename to api/graphon/nodes/variable_assigner/v2/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/graphon/nodes/variable_assigner/v2/entities.py similarity index 89% rename from api/dify_graph/nodes/variable_assigner/v2/entities.py rename to api/graphon/nodes/variable_assigner/v2/entities.py index 2b2bbe85de..d1c68c8e8c 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ b/api/graphon/nodes/variable_assigner/v2/entities.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from .enums import InputType, Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/enums.py b/api/graphon/nodes/variable_assigner/v2/enums.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v2/enums.py rename to api/graphon/nodes/variable_assigner/v2/enums.py diff --git a/api/dify_graph/nodes/variable_assigner/v2/exc.py b/api/graphon/nodes/variable_assigner/v2/exc.py similarity index 93% rename from api/dify_graph/nodes/variable_assigner/v2/exc.py rename to api/graphon/nodes/variable_assigner/v2/exc.py index c50aab8668..90d7648574 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/exc.py +++ b/api/graphon/nodes/variable_assigner/v2/exc.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError from .enums import InputType, Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/helpers.py b/api/graphon/nodes/variable_assigner/v2/helpers.py similarity index 98% rename from api/dify_graph/nodes/variable_assigner/v2/helpers.py rename to api/graphon/nodes/variable_assigner/v2/helpers.py index 38c69cbe3c..ebc6c79476 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/helpers.py +++ b/api/graphon/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,6 @@ from typing import Any -from dify_graph.variables import SegmentType +from graphon.variables import SegmentType from .enums import Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/graphon/nodes/variable_assigner/v2/node.py similarity index 71% rename from api/dify_graph/nodes/variable_assigner/v2/node.py rename to api/graphon/nodes/variable_assigner/v2/node.py index f04a6b3b80..887bd1b604 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/graphon/nodes/variable_assigner/v2/node.py @@ -1,16 +1,15 @@ import json -from collections.abc import Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.variables import SegmentType, Variable, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem @@ -24,14 +23,11 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return selector_str = ".".join(item.variable_selector) key = f"{node_id}.#{selector_str}#" mapping[key] = item.variable_selector @@ -103,15 +99,18 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): _source_mapping_from_item(var_mapping, node_id, item) return var_mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] + # Preserve intra-node read-after-write behavior without mutating the shared pool + # until the engine processes the emitted VariableUpdatedEvent instances. + working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) try: for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) + variable = working_variable_pool.get(item.variable_selector) # ==================== Validation Part @@ -136,60 +135,64 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) # Get value from variable pool + input_value = item.value if ( item.input_type == InputType.VARIABLE and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} and item.value is not None ): - value = self.graph_runtime_state.variable_pool.get(item.value) + value = working_variable_pool.get(item.value) if value is None: raise VariableNotFoundError(variable_selector=item.value) # Skip if value is NoneSegment if value.value_type == SegmentType.NONE: continue - item.value = value.value + input_value = value.value # If set string / bytes / bytearray to object, try convert string to object. if ( item.operation == Operation.SET and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) + and isinstance(input_value, str | bytes | bytearray) ): try: - item.value = json.loads(item.value) + input_value = json.loads(input_value) except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # Check if input value is valid if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value + variable_type=variable.value_type, operation=item.operation, value=input_value ): - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # ==================== Execution Part updated_value = self._handle_item( variable=variable, operation=item.operation, - value=item.value, + value=input_value, ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) + updated_variable = variable.model_copy(update={"value": updated_value}) + working_variable_pool.add(updated_variable.selector, updated_variable) + updated_variable_selectors.append(updated_variable.selector) except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + error=str(e), + ) ) + return # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. - updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + # remove duplicated items while preserving the first update order. + updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) + variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -197,15 +200,23 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + if (seg := working_variable_pool.get(selector)) is not None ] process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, + for selector in updated_variable_selectors: + variable = working_variable_pool.get(selector) + if not isinstance(variable, VariableBase): + raise VariableNotFoundError(variable_selector=selector) + yield VariableUpdatedEvent(variable=cast(Variable, variable)) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={}, + ) ) def _handle_item( diff --git a/api/graphon/prompt_entities.py b/api/graphon/prompt_entities.py new file mode 100644 index 0000000000..2b8b106c6c --- /dev/null +++ b/api/graphon/prompt_entities.py @@ -0,0 +1,47 @@ +from typing import Literal + +from pydantic import BaseModel + +from graphon.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """Graph-owned chat prompt template message.""" + + text: str + role: PromptMessageRole + edition_type: Literal["basic", "jinja2"] | None = None + + +class CompletionModelPromptTemplate(BaseModel): + """Graph-owned completion prompt template.""" + + text: str + edition_type: Literal["basic", "jinja2"] | None = None + + +class MemoryConfig(BaseModel): + """Graph-owned memory configuration for prompt assembly.""" + + class RolePrefix(BaseModel): + """Role labels used when serializing completion-model histories.""" + + user: str + assistant: str + + class WindowConfig(BaseModel): + """History windowing controls.""" + + enabled: bool + size: int | None = None + + role_prefix: RolePrefix | None = None + window: WindowConfig + query_prompt_template: str | None = None + + +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/dify_graph/runtime/__init__.py b/api/graphon/runtime/__init__.py similarity index 100% rename from api/dify_graph/runtime/__init__.py rename to api/graphon/runtime/__init__.py diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py similarity index 92% rename from api/dify_graph/runtime/graph_runtime_state.py rename to api/graphon/runtime/graph_runtime_state.py index 41acc6db35..8453830f28 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/graphon/runtime/graph_runtime_state.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib import json from collections.abc import Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol @@ -10,13 +11,13 @@ from typing import TYPE_CHECKING, Any, ClassVar, Protocol from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder -from dify_graph.enums import NodeExecutionType, NodeState, NodeType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.variable_pool import VariablePool +from graphon.enums import NodeExecutionType, NodeState, NodeType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime.variable_pool import VariablePool if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.entities.pause_reason import PauseReason + from graphon.entities import GraphInitParams + from graphon.entities.pause_reason import PauseReason class ReadyQueueProtocol(Protocol): @@ -51,6 +52,12 @@ class ReadyQueueProtocol(Protocol): ... +class NodeExecutionProtocol(Protocol): + """Structural interface for persisted per-node execution state.""" + + execution_id: str | None + + class GraphExecutionProtocol(Protocol): """Structural interface for graph execution aggregate. @@ -66,6 +73,11 @@ class GraphExecutionProtocol(Protocol): exceptions_count: int pause_reasons: list[PauseReason] + @property + def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: + """Return the persisted node execution state keyed by node id.""" + ... + def start(self) -> None: """Transition execution into the running state.""" ... @@ -142,10 +154,9 @@ class ChildGraphEngineBuilderProtocol(Protocol): *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: ... @@ -211,6 +222,7 @@ class GraphRuntimeState: graph_execution: GraphExecutionProtocol | None = None, response_coordinator: ResponseStreamCoordinatorProtocol | None = None, graph: GraphProtocol | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: self._variable_pool = variable_pool self._start_at = start_at @@ -231,6 +243,9 @@ class GraphRuntimeState: self._ready_queue = ready_queue self._graph_execution = graph_execution self._response_coordinator = response_coordinator + # Application code injects this when worker threads must restore request + # or framework-local state. It is intentionally excluded from snapshots. + self._execution_context = execution_context if execution_context is not None else nullcontext(None) self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() @@ -285,21 +300,19 @@ class GraphRuntimeState: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: + """Create a child graph engine that derives its runtime state from the parent.""" if self._child_engine_builder is None: raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") return self._child_engine_builder.build_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, + parent_graph_runtime_state=self, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) # ------------------------------------------------------------------ @@ -329,6 +342,14 @@ class GraphRuntimeState: self._response_coordinator = self._build_response_coordinator(self._graph) return self._response_coordinator + @property + def execution_context(self) -> AbstractContextManager[object]: + return self._execution_context + + @execution_context.setter + def execution_context(self, value: AbstractContextManager[object] | None) -> None: + self._execution_context = value if value is not None else nullcontext(None) + # ------------------------------------------------------------------ # Scalar state # ------------------------------------------------------------------ @@ -485,13 +506,13 @@ class GraphRuntimeState: # ------------------------------------------------------------------ def _build_ready_queue(self) -> ReadyQueueProtocol: # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("dify_graph.graph_engine.ready_queue") + module = importlib.import_module("graphon.graph_engine.ready_queue") in_memory_cls = module.InMemoryReadyQueue return in_memory_cls() def _build_graph_execution(self) -> GraphExecutionProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") + module = importlib.import_module("graphon.graph_engine.domain.graph_execution") graph_execution_cls = module.GraphExecution workflow_id = self._pending_graph_execution_workflow_id or "" self._pending_graph_execution_workflow_id = None @@ -499,7 +520,7 @@ class GraphRuntimeState: def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.response_coordinator") + module = importlib.import_module("graphon.graph_engine.response_coordinator") coordinator_cls = module.ResponseStreamCoordinator return coordinator_cls(variable_pool=self.variable_pool, graph=graph) diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/graphon/runtime/graph_runtime_state_protocol.py similarity index 89% rename from api/dify_graph/runtime/graph_runtime_state_protocol.py rename to api/graphon/runtime/graph_runtime_state_protocol.py index 7e55ece3f1..856625a5d3 100644 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ b/api/graphon/runtime/graph_runtime_state_protocol.py @@ -1,9 +1,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.variables.segments import Segment class ReadOnlyVariablePool(Protocol): @@ -31,9 +30,6 @@ class ReadOnlyGraphRuntimeState(Protocol): All methods return defensive copies to ensure immutability. """ - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... - @property def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/graphon/runtime/read_only_wrappers.py similarity index 88% rename from api/dify_graph/runtime/read_only_wrappers.py rename to api/graphon/runtime/read_only_wrappers.py index ca06d88c3d..aaef255204 100644 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ b/api/graphon/runtime/read_only_wrappers.py @@ -4,9 +4,8 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool @@ -43,10 +42,6 @@ class ReadOnlyGraphRuntimeStateWrapper: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return self._state.variable_pool.system_variables.as_view() - @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: return self._variable_pool_wrapper diff --git a/api/dify_graph/runtime/variable_pool.py b/api/graphon/runtime/variable_pool.py similarity index 63% rename from api/dify_graph/runtime/variable_pool.py rename to api/graphon/runtime/variable_pool.py index e3ef6a2897..b44d1a8abe 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/graphon/runtime/variable_pool.py @@ -6,84 +6,84 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Annotated, Any, Union, cast -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import Segment, SegmentGroup, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import FileSegment, ObjectSegment -from dify_graph.variables.variables import RAGPipelineVariableInput, Variable -from factories import variable_factory +from graphon.file import File, FileAttribute, file_manager +from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import FileSegment, ObjectSegment +from graphon.variables.variables import RAGPipelineVariableInput, Variable VariableValue = Union[str, int, float, dict[str, object], list[object], File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") +def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: + return defaultdict(dict) + + class VariablePool(BaseModel): + _SYSTEM_VARIABLE_NODE_ID = "sys" + _ENVIRONMENT_VARIABLE_NODE_ID = "env" + _CONVERSATION_VARIABLE_NODE_ID = "conversation" + _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" + # Variable dictionary is a dictionary for looking up variables by their selector. # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", - default=defaultdict(dict), + default_factory=_default_variable_dictionary, ) + system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) + user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) + @model_validator(mode="after") + def _load_legacy_bootstrap_inputs(self) -> VariablePool: + """ + Accept legacy constructor kwargs that still appear throughout the workflow + layer while keeping serialized state focused on `variable_dictionary`. + """ - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) + self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) + self._ingest_legacy_rag_variables(self.rag_pipeline_variables) + + # These kwargs are accepted for compatibility but should not affect the + # stable serialized form or model equality. + self.system_variables = () + self.environment_variables = () + self.conversation_variables = () + self.rag_pipeline_variables = () + self.user_inputs = {} + return self + + def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: + for variable in variables: + selector = [node_id, variable.name] + normalized_variable = variable + if list(variable.selector) != selector: + normalized_variable = variable.model_copy(update={"selector": selector}) + self.add(normalized_variable.selector, normalized_variable) + + def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: + if not rag_pipeline_variables: + return + + values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_variable_input in rag_pipeline_variables: + values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( + rag_variable_input.value + ) + + for node_id, value in values_by_node_id.items(): + self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) def add(self, selector: Sequence[str], value: Any, /): """ @@ -114,10 +114,10 @@ class VariablePool(BaseModel): if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): - variable = variable_factory.segment_to_variable(segment=value, selector=selector) + variable = segment_to_variable(segment=value, selector=selector) else: - segment = variable_factory.build_segment(value) - variable = variable_factory.segment_to_variable(segment=segment, selector=selector) + segment = build_segment(value) + variable = segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) # Based on the definition of `Variable`, @@ -180,7 +180,7 @@ class VariablePool(BaseModel): return None attr = FileAttribute(attr) attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return variable_factory.build_segment(attr_value) + return build_segment(attr_value) # Navigate through nested attributes result: Any = segment @@ -191,7 +191,7 @@ class VariablePool(BaseModel): return None # Return result as Segment - return result if isinstance(result, Segment) else variable_factory.build_segment(result) + return result if isinstance(result, Segment) else build_segment(result) def _extract_value(self, obj: Any): """Extract the actual value from an ObjectSegment.""" @@ -212,7 +212,7 @@ class VariablePool(BaseModel): """ if not isinstance(obj, dict) or attr not in obj: return None - return variable_factory.build_segment(obj.get(attr)) + return build_segment(obj.get(attr)) def remove(self, selector: Sequence[str], /): """ @@ -239,7 +239,7 @@ class VariablePool(BaseModel): if "." in part and (variable := self.get(part.split("."))): segments.append(variable) else: - segments.append(variable_factory.build_segment(part)) + segments.append(build_segment(part)) return SegmentGroup(value=segments) def get_file(self, selector: Sequence[str], /) -> FileSegment | None: @@ -262,19 +262,18 @@ class VariablePool(BaseModel): return result - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. - if self._has(selector): - continue - self.add(selector, value) + def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: + """Return a selector-style snapshot of the entire variable pool.""" + + result: dict[str, object] = {} + for node_id, variables in self.variable_dictionary.items(): + for name, variable in variables.items(): + output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" + result[output_name] = deepcopy(variable.value) + + return result @classmethod def empty(cls) -> VariablePool: """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) + return cls() diff --git a/api/graphon/template_rendering.py b/api/graphon/template_rendering.py new file mode 100644 index 0000000000..0527e58f6d --- /dev/null +++ b/api/graphon/template_rendering.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any + + +class TemplateRenderError(ValueError): + """Raised when rendering a template fails.""" + + +class Jinja2TemplateRenderer(ABC): + """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" + + @abstractmethod + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + """Render the template into plain text.""" + raise NotImplementedError diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py b/api/graphon/utils/__init__.py similarity index 100% rename from api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py rename to api/graphon/utils/__init__.py diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__init__.py b/api/graphon/utils/condition/__init__.py similarity index 100% rename from api/tests/unit_tests/dify_graph/model_runtime/__init__.py rename to api/graphon/utils/condition/__init__.py diff --git a/api/dify_graph/utils/condition/entities.py b/api/graphon/utils/condition/entities.py similarity index 100% rename from api/dify_graph/utils/condition/entities.py rename to api/graphon/utils/condition/entities.py diff --git a/api/dify_graph/utils/condition/processor.py b/api/graphon/utils/condition/processor.py similarity index 98% rename from api/dify_graph/utils/condition/processor.py rename to api/graphon/utils/condition/processor.py index dea72d96c2..03535927cb 100644 --- a/api/dify_graph/utils/condition/processor.py +++ b/api/graphon/utils/condition/processor.py @@ -2,10 +2,10 @@ import json from collections.abc import Mapping, Sequence from typing import Literal, NamedTuple -from dify_graph.file import FileAttribute, file_manager -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment +from graphon.file import FileAttribute, file_manager +from graphon.runtime import VariablePool +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayBooleanSegment, BooleanSegment from .entities import Condition, SubCondition, SupportedComparisonOperator diff --git a/api/graphon/utils/json_in_md_parser.py b/api/graphon/utils/json_in_md_parser.py new file mode 100644 index 0000000000..4416b4774b --- /dev/null +++ b/api/graphon/utils/json_in_md_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json + + +class OutputParserError(ValueError): + """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" + + +def parse_json_markdown(json_string: str) -> dict | list: + """Extract and parse the first JSON object or array embedded in markdown text.""" + json_string = json_string.strip() + starts = ["```json", "```", "``", "`", "{", "["] + ends = ["```", "``", "`", "}", "]"] + end_index = -1 + start_index = 0 + + for start_marker in starts: + start_index = json_string.find(start_marker) + if start_index != -1: + if json_string[start_index] not in ("{", "["): + start_index += len(start_marker) + break + + if start_index != -1: + for end_marker in ends: + end_index = json_string.rfind(end_marker, start_index) + if end_index != -1: + if json_string[end_index] in ("}", "]"): + end_index += 1 + break + + if start_index == -1 or end_index == -1 or start_index >= end_index: + raise ValueError("could not find json block in the output.") + + extracted_content = json_string[start_index:end_index].strip() + return json.loads(extracted_content) + + +def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as exc: + raise OutputParserError(f"got invalid json object. error: {exc}") from exc + + if isinstance(json_obj, list): + if len(json_obj) == 1 and isinstance(json_obj[0], dict): + json_obj = json_obj[0] + else: + raise OutputParserError(f"got invalid return object. obj:{json_obj}") + + for key in expected_keys: + if key not in json_obj: + raise OutputParserError( + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" + ) + + return json_obj diff --git a/api/dify_graph/variable_loader.py b/api/graphon/variable_loader.py similarity index 82% rename from api/dify_graph/variable_loader.py rename to api/graphon/variable_loader.py index d263450334..03db920d3d 100644 --- a/api/dify_graph/variable_loader.py +++ b/api/graphon/variable_loader.py @@ -2,9 +2,9 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH +from graphon.runtime import VariablePool +from graphon.variables import VariableBase +from graphon.variables.consts import SELECTORS_LENGTH class VariableLoader(Protocol): @@ -13,14 +13,6 @@ class VariableLoader(Protocol): A `VariableLoader` is responsible for retrieving additional variables required during the execution of a single node, which are not provided as user inputs. - NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into `WorkflowService.single_step_run`, we may get rid of this interface. """ diff --git a/api/dify_graph/variables/__init__.py b/api/graphon/variables/__init__.py similarity index 83% rename from api/dify_graph/variables/__init__.py rename to api/graphon/variables/__init__.py index be3fc8d97a..e9beb6cb95 100644 --- a/api/dify_graph/variables/__init__.py +++ b/api/graphon/variables/__init__.py @@ -1,3 +1,10 @@ +from .factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, +) from .input_entities import VariableEntity, VariableEntityType from .segment_group import SegmentGroup from .segments import ( @@ -63,8 +70,13 @@ __all__ = [ "SegmentType", "StringSegment", "StringVariable", + "TypeMismatchError", + "UnsupportedSegmentTypeError", "Variable", "VariableBase", "VariableEntity", "VariableEntityType", + "build_segment", + "build_segment_with_type", + "segment_to_variable", ] diff --git a/api/dify_graph/variables/consts.py b/api/graphon/variables/consts.py similarity index 100% rename from api/dify_graph/variables/consts.py rename to api/graphon/variables/consts.py diff --git a/api/dify_graph/variables/exc.py b/api/graphon/variables/exc.py similarity index 100% rename from api/dify_graph/variables/exc.py rename to api/graphon/variables/exc.py diff --git a/api/graphon/variables/factory.py b/api/graphon/variables/factory.py new file mode 100644 index 0000000000..ac693914a7 --- /dev/null +++ b/api/graphon/variables/factory.py @@ -0,0 +1,202 @@ +"""Graph-owned helpers for converting runtime values, segments, and variables. + +These conversions are part of the `graphon` runtime model and must stay +independent from top-level API factory modules so graph nodes and state +containers can operate without importing application-layer packages. +""" + +from collections.abc import Mapping, Sequence +from typing import Any, cast +from uuid import uuid4 + +from graphon.file import File + +from .segments import ( + ArrayAnySegment, + ArrayBooleanSegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArraySegment, + ArrayStringSegment, + BooleanSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType +from .variables import ( + ArrayAnyVariable, + ArrayBooleanVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + BooleanVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + StringVariable, + VariableBase, +) + + +class UnsupportedSegmentTypeError(Exception): + pass + + +class TypeMismatchError(Exception): + pass + + +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { + ArrayAnySegment: ArrayAnyVariable, + ArrayBooleanSegment: ArrayBooleanVariable, + ArrayFileSegment: ArrayFileVariable, + ArrayNumberSegment: ArrayNumberVariable, + ArrayObjectSegment: ArrayObjectVariable, + ArrayStringSegment: ArrayStringVariable, + BooleanSegment: BooleanVariable, + FileSegment: FileVariable, + FloatSegment: FloatVariable, + IntegerSegment: IntegerVariable, + NoneSegment: NoneVariable, + ObjectSegment: ObjectVariable, + StringSegment: StringVariable, +} + + +def build_segment(value: Any, /) -> Segment: + """Build a runtime segment from a Python value.""" + if value is None: + return NoneSegment() + if isinstance(value, Segment): + return value + if isinstance(value, str): + return StringSegment(value=value) + if isinstance(value, bool): + return BooleanSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) + if isinstance(value, list): + items = [build_segment(item) for item in value] + types = {item.value_type for item in items} + if all(isinstance(item, ArraySegment) for item in items): + return ArrayAnySegment(value=value) + if len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return ArrayNumberSegment(value=value) + case SegmentType.BOOLEAN: + return ArrayBooleanSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) + case _: + raise ValueError(f"not supported value {value}") + raise ValueError(f"not supported value {value}") + + +_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.BOOLEAN: BooleanSegment, + SegmentType.OBJECT: ObjectSegment, + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, + SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, +} + + +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """Build a segment while enforcing compatibility with the expected runtime type.""" + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") + + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + if segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + if segment_type == SegmentType.ARRAY_BOOLEAN: + return ArrayBooleanSegment(value=value) + if segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + if segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + if segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + + inferred_type = SegmentType.infer_segment_type(value) + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) + if inferred_type == segment_type: + segment_class = _SEGMENT_FACTORY[segment_type] + return segment_class(value_type=segment_type, value=value) + if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): + segment_class = _SEGMENT_FACTORY[inferred_type] + return segment_class(value_type=inferred_type, value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") + + +def segment_to_variable( + *, + segment: Segment, + selector: Sequence[str], + id: str | None = None, + name: str | None = None, + description: str = "", +) -> VariableBase: + """Convert a runtime segment into a runtime variable for storage in the pool.""" + if isinstance(segment, VariableBase): + return segment + name = name or selector[-1] + id = id or str(uuid4()) + + segment_type = type(segment) + if segment_type not in SEGMENT_TO_VARIABLE_MAP: + raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") + + variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] + return cast( + VariableBase, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=list(selector), + ), + ) diff --git a/api/dify_graph/variables/input_entities.py b/api/graphon/variables/input_entities.py similarity index 97% rename from api/dify_graph/variables/input_entities.py rename to api/graphon/variables/input_entities.py index e6a68ea359..c46ee47714 100644 --- a/api/dify_graph/variables/input_entities.py +++ b/api/graphon/variables/input_entities.py @@ -5,7 +5,7 @@ from typing import Any from jsonschema import Draft7Validator, SchemaError from pydantic import BaseModel, Field, field_validator -from dify_graph.file import FileTransferMethod, FileType +from graphon.file import FileTransferMethod, FileType class VariableEntityType(StrEnum): diff --git a/api/dify_graph/variables/segment_group.py b/api/graphon/variables/segment_group.py similarity index 100% rename from api/dify_graph/variables/segment_group.py rename to api/graphon/variables/segment_group.py diff --git a/api/dify_graph/variables/segments.py b/api/graphon/variables/segments.py similarity index 99% rename from api/dify_graph/variables/segments.py rename to api/graphon/variables/segments.py index bdb213ed48..8902ddc7e9 100644 --- a/api/dify_graph/variables/segments.py +++ b/api/graphon/variables/segments.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, TypeAlias from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator -from dify_graph.file import File +from graphon.file import File from .types import SegmentType diff --git a/api/dify_graph/variables/types.py b/api/graphon/variables/types.py similarity index 93% rename from api/dify_graph/variables/types.py rename to api/graphon/variables/types.py index 53bf495a27..949a693ad2 100644 --- a/api/dify_graph/variables/types.py +++ b/api/graphon/variables/types.py @@ -4,10 +4,10 @@ from collections.abc import Mapping from enum import StrEnum from typing import TYPE_CHECKING, Any -from dify_graph.file.models import File +from graphon.file.models import File if TYPE_CHECKING: - from dify_graph.variables.segments import Segment + from graphon.variables.segments import Segment class ArrayValidation(StrEnum): @@ -220,8 +220,8 @@ class SegmentType(StrEnum): @staticmethod def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency - from factories import variable_factory + # Lazy import to avoid circular dependency between segment types and factory helpers. + from graphon.variables.factory import build_segment, build_segment_with_type match t: case ( @@ -231,19 +231,19 @@ class SegmentType(StrEnum): | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN ): - return variable_factory.build_segment_with_type(t, []) + return build_segment_with_type(t, []) case SegmentType.OBJECT: - return variable_factory.build_segment({}) + return build_segment({}) case SegmentType.STRING: - return variable_factory.build_segment("") + return build_segment("") case SegmentType.INTEGER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.FLOAT: - return variable_factory.build_segment(0.0) + return build_segment(0.0) case SegmentType.NUMBER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.BOOLEAN: - return variable_factory.build_segment(False) + return build_segment(False) case _: raise ValueError(f"unsupported variable type: {t}") diff --git a/api/dify_graph/variables/utils.py b/api/graphon/variables/utils.py similarity index 100% rename from api/dify_graph/variables/utils.py rename to api/graphon/variables/utils.py diff --git a/api/dify_graph/variables/variables.py b/api/graphon/variables/variables.py similarity index 100% rename from api/dify_graph/variables/variables.py rename to api/graphon/variables/variables.py diff --git a/api/dify_graph/workflow_type_encoder.py b/api/graphon/workflow_type_encoder.py similarity index 95% rename from api/dify_graph/workflow_type_encoder.py rename to api/graphon/workflow_type_encoder.py index 3dd846b3cb..7cdc83ebdb 100644 --- a/api/dify_graph/workflow_type_encoder.py +++ b/api/graphon/workflow_type_encoder.py @@ -4,8 +4,8 @@ from typing import Any, overload from pydantic import BaseModel -from dify_graph.file.models import File -from dify_graph.variables import Segment +from graphon.file.models import File +from graphon.variables import Segment class WorkflowRuntimeTypeConverter: diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index d4cb3e9971..8eeac37232 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -125,7 +125,8 @@ class BroadcastChannel(Protocol): a specific topic, all subscription should receive the published message. There are no restriction for the persistence of messages. Once a subscription is created, it - should receive all subsequent messages published. + should receive all subsequent messages published. However, a subscription should not receive + any message published before the subscription is established. `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. """ diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..983f785027 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -63,21 +63,45 @@ class _StreamsSubscription(Subscription): def __init__(self, client: Redis | RedisCluster, key: str): self._client = client self._key = key - self._closed = threading.Event() - self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() - self._start_lock = threading.Lock() + + # The `_lock` lock is used to + # + # 1. protect the _listener attribute + # 2. prevent repeated releases of underlying resoueces. (The _closed flag.) + # + # INVARIANT: the implementation must hold the lock while + # reading and writing the _listener / `_closed` attribute. + self._lock = threading.Lock() + self._closed: bool = False + # self._closed = threading.Event() self._listener: threading.Thread | None = None def _listen(self) -> None: - try: - while not self._closed.is_set(): - streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + """The `_listen` method handles the message retrieval loop. It requires a dedicated thread + and is not intended for direct invocation. + The thread is started by `_start_if_needed`. + """ + + # since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause + # deadlock. + + # Setting initial last id to `$` to signal redis that we only want new messages. + # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + last_id = "$" + try: + while True: + with self._lock: + if self._closed: + break + streams = self._client.xread({self._key: last_id}, block=1000, count=100) if not streams: continue - for _key, entries in streams: + for _, entries in streams: for entry_id, fields in entries: data = None if isinstance(fields, dict): @@ -89,37 +113,48 @@ class _StreamsSubscription(Subscription): data_bytes = bytes(data) if data_bytes is not None: self._queue.put_nowait(data_bytes) - self._last_id = entry_id + last_id = entry_id finally: self._queue.put_nowait(self._SENTINEL) - self._listener = None + with self._lock: + self._listener = None + self._closed = True def _start_if_needed(self) -> None: + """This method must be called with `_lock` held.""" if self._listener is not None: return # Ensure only one listener thread is created under concurrent calls - with self._start_lock: - if self._listener is not None or self._closed.is_set(): - return - self._listener = threading.Thread( - target=self._listen, - name=f"redis-streams-sub-{self._key}", - daemon=True, - ) - self._listener.start() + if self._listener is not None or self._closed: + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() def __iter__(self) -> Iterator[bytes]: # Iterator delegates to receive with timeout; stops on closure. - self._start_if_needed() - while not self._closed.is_set(): - item = self.receive(timeout=1) + with self._lock: + self._start_if_needed() + + while True: + with self._lock: + if self._closed: + return + try: + item = self.receive(timeout=1) + except SubscriptionClosedError: + return if item is not None: yield item def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed.is_set(): - raise SubscriptionClosedError("The Redis streams subscription is closed") - self._start_if_needed() + with self._lock: + if self._closed: + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() try: if timeout is None: @@ -129,29 +164,33 @@ class _StreamsSubscription(Subscription): except queue.Empty: return None - if item is self._SENTINEL or self._closed.is_set(): + if item is self._SENTINEL: raise SubscriptionClosedError("The Redis streams subscription is closed") assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" return bytes(item) def close(self) -> None: - if self._closed.is_set(): - return - self._closed.set() - listener = self._listener - if listener is not None: + with self._lock: + if self._closed: + return + self._closed = True + listener = self._listener + if listener is not None: + self._listener = None + # We close the listener outside of the with block to avoid holding the + # lock for a long time. + if listener is not None and listener.is_alive(): listener.join(timeout=2.0) if listener.is_alive(): logger.warning( "Streams subscription listener for key %s did not stop within timeout; keeping reference.", self._key, ) - else: - self._listener = None # Context manager helpers def __enter__(self) -> Self: - self._start_if_needed() + with self._lock: + self._start_if_needed() return self def __exit__(self, exc_type, exc_value, traceback) -> bool | None: diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py index c08578981b..e0a6ec2cac 100644 --- a/api/libs/datetime_utils.py +++ b/api/libs/datetime_utils.py @@ -2,7 +2,7 @@ import abc import datetime from typing import Protocol -import pytz +import pytz # type: ignore[import-untyped] class _NowFunction(Protocol): diff --git a/api/libs/helper.py b/api/libs/helper.py index e7572cc025..b1815859a5 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -21,9 +21,9 @@ from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account @@ -174,6 +174,18 @@ def normalize_uuid(value: str | UUID) -> str: raise ValueError("must be a valid UUID") from exc +def parse_uuid_str_or_none(value: str | None) -> str | None: + """ + Return None for missing/empty UUID-like values. + + Keep non-empty values unchanged to avoid changing behavior in paths that + currently pass placeholder IDs in tests/mocks. + """ + if value is None or not str(value).strip(): + return None + return str(value) + + UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index efce13f6f1..76e741301c 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,16 +1,19 @@ +import logging import sys import urllib.parse from dataclasses import dataclass from typing import NotRequired import httpx -from pydantic import TypeAdapter +from pydantic import TypeAdapter, ValidationError if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict +logger = logging.getLogger(__name__) + JsonObject = dict[str, object] JsonObjectList = list[JsonObject] @@ -25,13 +28,14 @@ class AccessTokenResponse(TypedDict, total=False): class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool + verified: bool class GitHubRawUserInfo(TypedDict): id: int | str login: str - name: NotRequired[str] - email: NotRequired[str] + name: NotRequired[str | None] + email: NotRequired[str | None] class GoogleRawUserInfo(TypedDict): @@ -127,18 +131,52 @@ class GitHubOAuth(OAuth): response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) - primary_email = next((email for email in email_info if email.get("primary") is True), None) + # Only call the /user/emails endpoint when the profile email is absent, + # i.e. the user has "Keep my email addresses private" enabled. + resolved_email = user_info.get("email") or "" + if not resolved_email: + resolved_email = self._get_email_from_emails_endpoint(headers) - return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} + return {**user_info, "email": resolved_email} + + @staticmethod + def _get_email_from_emails_endpoint(headers: dict[str, str]) -> str: + """Fetch the best available email from GitHub's /user/emails endpoint. + + Prefers the primary email, then falls back to any verified email. + Returns an empty string when no usable email is found. + """ + try: + email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) + email_response.raise_for_status() + email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + except (httpx.HTTPStatusError, ValidationError): + logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) + return "" + + primary = next((r for r in email_records if r.get("primary") is True), None) + if primary: + return primary.get("email", "") + + # No primary email; try any verified email as a fallback. + verified = next((r for r in email_records if r.get("verified") is True), None) + if verified: + return verified.get("email", "") + + return "" def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) - email = payload.get("email") + email = payload.get("email") or "" if not email: - email = f"{payload['id']}+{payload['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email) + # When no email is available from the profile or /user/emails endpoint, + # fall back to GitHub's noreply address so sign-in can still proceed. + # Use only the numeric ID (not the login) so the address stays stable + # even if the user renames their GitHub account. + github_id = payload["id"] + email = f"{github_id}@users.noreply.github.com" + logger.info("GitHub user %s has no public email; using noreply address", payload["login"]) + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) class GoogleOAuth(OAuth): diff --git a/api/libs/schedule_utils.py b/api/libs/schedule_utils.py index 1ab5f499e9..b80a5ea722 100644 --- a/api/libs/schedule_utils.py +++ b/api/libs/schedule_utils.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -import pytz +import pytz # type: ignore[import-untyped] from croniter import croniter diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -43,7 +43,9 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -135,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -494,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -998,7 +1002,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/enums.py b/api/models/enums.py index 4849099d30..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" @@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py index d0bd34efec..b2d09a7732 100644 --- a/api/models/execution_extra_content.py +++ b/api/models/execution_extra_content.py @@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent): form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) + def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent": + return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id) form: Mapped["HumanInputForm"] = relationship( "HumanInputForm", diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9ea..b4c7a634b6 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,11 +6,8 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index a08e43d128..bcb142db56 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,10 +3,11 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto +from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 @@ -20,17 +21,19 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db from .enums import ( + ApiTokenType, AppMCPServerStatus, AppStatus, BannerStatus, @@ -43,6 +46,8 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + ProviderQuotaType, + TagType, ) from .provider_ids import GenericProviderID from .types import EnumText, LongText, StringUUID @@ -54,6 +59,32 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return resolved_tenant_id + + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id + + class EnabledConfig(TypedDict): enabled: bool @@ -586,7 +617,9 @@ class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) @@ -935,7 +968,9 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1050,23 +1085,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `graphon.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1079,15 +1117,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1395,21 +1430,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1422,15 +1459,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1605,6 +1639,7 @@ class Message(Base): "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: @@ -1618,6 +1653,7 @@ class Message(Base): "url": message_file.url, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: @@ -1632,6 +1668,7 @@ class Message(Base): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) else: raise ValueError( @@ -1782,7 +1819,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) @@ -1844,7 +1881,9 @@ class AppAnnotationHitHistory(TypeBase): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) source: Mapped[str] = mapped_column(LongText, nullable=False) @@ -2094,7 +2133,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -2404,7 +2443,7 @@ class Tag(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - type: Mapped[str] = mapped_column(String(16), nullable=False) + type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2489,7 +2528,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/tools.py b/api/models/tools.py index c09f054e7d..63b27b9413 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolProviderType, + WorkflowToolParameterConfiguration, +) from .base import TypeBase from .engine import db from .model import Account, App, Tenant -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from core.entities.mcp_provider import MCPProviderEntity @@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase): # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # label name label_name: Mapped[str] = mapped_column(String(40), nullable=False) @@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase): # provider provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # tool name tool_name: Mapped[str] = mapped_column(String(128), nullable=False) # invoke parameters diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 0000000000..b390b8106b --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py new file mode 100644 index 0000000000..dee1cc507a --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from functools import lru_cache +from typing import Any + +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod + + +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_stored_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_id: str, +) -> File: + """ + Canonicalize a persisted file payload against the current tenant context. + + Stored JSON rows can outlive file schema changes, so rebuild storage-backed + files through the workflow factory instead of trusting serialized metadata. + Pure external ``REMOTE_URL`` payloads without a backing upload row are + passed through because there is no server-owned record to rebind. + """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + mapping.pop("tenant_id", None) + record_id = resolve_file_record_id(mapping) + transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) + + if transfer_method == FileTransferMethod.TOOL_FILE and record_id: + mapping["tool_file_id"] = record_id + elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: + mapping["upload_file_id"] = record_id + elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: + mapping["datasource_file_id"] = record_id + + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + if isinstance(url, str) and url: + mapping["remote_url"] = url + return File.model_validate(mapping) + + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_get_file_access_controller(), + ) + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `graphon.file.File`. + """ + + transfer_method = FileTransferMethod.value_of(file_mapping["transfer_method"]) + record_id = resolve_file_record_id(file_mapping) + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id="", + ) + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=file_mapping, tenant_resolver=tenant_resolver) + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id=tenant_id, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e8dda429d..0557e2e890 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,19 +24,26 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import ( +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey -from dify_graph.file.constants import maybe_file_object -from dify_graph.file.models import File -from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file.constants import maybe_file_object +from graphon.file.models import File +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -48,8 +55,8 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account @@ -57,6 +64,7 @@ from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID +from .utils.file_input_compat import build_file_from_stored_mapping logger = logging.getLogger(__name__) @@ -64,6 +72,15 @@ SerializedWorkflowValue = dict[str, Any] SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] +def _resolve_workflow_app_tenant_id(app_id: str) -> str: + from .model import App + + tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return tenant_id + + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] features: dict[str, Any] @@ -273,7 +290,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -418,7 +435,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `dify_graph.nodes` + For specific node type, refer to `graphon.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -930,7 +947,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo inputs: Mapped[str | None] = mapped_column(LongText) process_data: Mapped[str | None] = mapped_column(LongText) outputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[str] = mapped_column(String(255)) + status: Mapped[WorkflowNodeExecutionStatus] = mapped_column(EnumText(WorkflowNodeExecutionStatus, length=255)) error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) @@ -1221,7 +1238,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1301,10 +1320,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) @@ -1444,7 +1467,7 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. # - # ref: api/dify_graph/entities/variable_pool.py:18 + # ref: api/graphon/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), @@ -1559,10 +1582,9 @@ class WorkflowDraftVariable(Base): def _loads_value(self) -> Segment: value = json.loads(self.value) - return self.build_segment_with_type(self.value_type, value) + return self.build_segment_from_serialized_value(self.value_type, value) - @staticmethod - def rebuild_file_types(value: Any): + def _rebuild_file_types(self, value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. # By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to @@ -1576,13 +1598,72 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, + ) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + else: + return cast(Any, value) + + def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: + # Persisted draft variable rows may contain historical file payloads. + # Rebuild them through the file factory so tenant ownership, signed URLs, + # and storage-backed metadata come from canonical records instead of the + # serialized JSON blob. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + @staticmethod + def rebuild_file_types(value: Any): + # Keep the class-level fallback for callers that only need lightweight + # structural reconstruction. Persisted draft-variable payloads should go + # through `build_segment_from_serialized_value()` so file metadata is + # rebuilt from canonical storage records. + if isinstance(value, dict): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) + elif isinstance(value, list) and value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/pyproject.toml b/api/pyproject.toml index f824fe7c23..b1f1f4bb2e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.2" +version = "1.13.3" requires-python = ">=3.11,<3.13" dependencies = [ @@ -8,7 +8,7 @@ dependencies = [ "arize-phoenix-otel~=0.15.0", "azure-identity==1.25.3", "beautifulsoup4==4.14.3", - "boto3==1.42.68", + "boto3==1.42.73", "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.6.2", @@ -23,7 +23,7 @@ dependencies = [ "gevent~=25.9.1", "gmpy2~=2.3.0", "google-api-core>=2.19.1", - "google-api-python-client==2.192.0", + "google-api-python-client==2.193.0", "google-auth>=2.47.0", "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", @@ -40,7 +40,7 @@ dependencies = [ "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.10.37", - "litellm==1.82.2", # Pinned to avoid madoka dependency issue + "litellm==1.82.6", # Pinned to avoid madoka dependency issue "opentelemetry-api==1.28.0", "opentelemetry-distro==0.49b0", "opentelemetry-exporter-otlp==1.28.0", @@ -72,13 +72,14 @@ dependencies = [ "pyyaml~=6.0.1", "readabilipy~=0.3.0", "redis[hiredis]~=7.3.0", - "resend~=2.23.0", - "sentry-sdk[flask]~=2.54.0", + "resend~=2.26.0", + "sentry-sdk[flask]~=2.55.0", "sqlalchemy~=2.0.29", - "starlette==0.52.1", + "starlette==1.0.0", "tiktoken~=0.12.0", "transformers~=5.3.0", "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", + "pypandoc~=1.13", "yarl~=1.23.0", "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", @@ -91,7 +92,7 @@ dependencies = [ "apscheduler>=3.11.0", "weave>=0.52.16", "fastopenapi[flask]>=0.7.0", - "bleach~=6.2.0", + "bleach~=6.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -118,7 +119,7 @@ dev = [ "ruff~=0.15.5", "pytest~=9.0.2", "pytest-benchmark~=5.2.3", - "pytest-cov~=7.0.0", + "pytest-cov~=7.1.0", "pytest-env~=1.6.0", "pytest-mock~=3.15.1", "testcontainers~=4.14.1", @@ -173,7 +174,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.55.0", + "pyrefly>=0.57.1", ] ############################################################ @@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] # Required by vector store clients ############################################################ vdb = [ - "alibabacloud_gpdb20160503~=3.8.0", + "alibabacloud_gpdb20160503~=5.1.0", "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", "clickhouse-connect~=0.14.1", @@ -230,26 +231,6 @@ vdb = [ "holo-search-sdk>=0.4.1", ] -[tool.mypy] - -[[tool.mypy.overrides]] -# targeted ignores for current type-check errors -# TODO(QuantumGhost): suppress type errors in HITL related code. -# fix the type error later -module = [ - "configs.middleware.cache.redis_pubsub_config", - "extensions.ext_redis", - "tasks.workflow_execution_tasks", - "dify_graph.nodes.base.node", - "services.human_input_delivery_test_service", - "core.app.apps.advanced_chat.app_generator", - "controllers.console.human_input_form", - "controllers.console.app.workflow_run", - "repositories.sqlalchemy_api_workflow_node_execution_repository", - "extensions.logstore.repositories.logstore_api_workflow_run_repository", -] -ignore_errors = true - [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index ad3c1e8389..cf002df2a9 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -109,34 +109,43 @@ core/trigger/debug/event_selectors.py core/trigger/entities/entities.py core/trigger/provider.py core/workflow/workflow_entry.py -dify_graph/entities/workflow_execution.py -dify_graph/file/file_manager.py -dify_graph/graph_engine/error_handler.py -dify_graph/graph_engine/layers/execution_limits.py -dify_graph/nodes/agent/agent_node.py -dify_graph/nodes/base/node.py -dify_graph/nodes/code/code_node.py -dify_graph/nodes/datasource/datasource_node.py -dify_graph/nodes/document_extractor/node.py -dify_graph/nodes/human_input/human_input_node.py -dify_graph/nodes/if_else/if_else_node.py -dify_graph/nodes/iteration/iteration_node.py -dify_graph/nodes/knowledge_index/knowledge_index_node.py +enterprise/telemetry/contracts.py +enterprise/telemetry/draft_trace.py +enterprise/telemetry/enterprise_trace.py +enterprise/telemetry/entities/__init__.py +enterprise/telemetry/event_handlers.py +enterprise/telemetry/exporter.py +enterprise/telemetry/id_generator.py +enterprise/telemetry/metric_handler.py +enterprise/telemetry/telemetry_log.py +graphon/entities/workflow_execution.py +graphon/file/file_manager.py +graphon/graph_engine/error_handler.py +graphon/graph_engine/layers/execution_limits.py +graphon/nodes/agent/agent_node.py +graphon/nodes/base/node.py +graphon/nodes/code/code_node.py +graphon/nodes/datasource/datasource_node.py +graphon/nodes/document_extractor/node.py +graphon/nodes/human_input/human_input_node.py +graphon/nodes/if_else/if_else_node.py +graphon/nodes/iteration/iteration_node.py +graphon/nodes/knowledge_index/knowledge_index_node.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -dify_graph/nodes/list_operator/node.py -dify_graph/nodes/llm/node.py -dify_graph/nodes/loop/loop_node.py -dify_graph/nodes/parameter_extractor/parameter_extractor_node.py -dify_graph/nodes/question_classifier/question_classifier_node.py -dify_graph/nodes/start/start_node.py -dify_graph/nodes/template_transform/template_transform_node.py -dify_graph/nodes/tool/tool_node.py -dify_graph/nodes/trigger_plugin/trigger_event_node.py -dify_graph/nodes/trigger_schedule/trigger_schedule_node.py -dify_graph/nodes/trigger_webhook/node.py -dify_graph/nodes/variable_aggregator/variable_aggregator_node.py -dify_graph/nodes/variable_assigner/v1/node.py -dify_graph/nodes/variable_assigner/v2/node.py +graphon/nodes/list_operator/node.py +graphon/nodes/llm/node.py +graphon/nodes/loop/loop_node.py +graphon/nodes/parameter_extractor/parameter_extractor_node.py +graphon/nodes/question_classifier/question_classifier_node.py +graphon/nodes/start/start_node.py +graphon/nodes/template_transform/template_transform_node.py +graphon/nodes/tool/tool_node.py +graphon/nodes/trigger_plugin/trigger_event_node.py +graphon/nodes/trigger_schedule/trigger_schedule_node.py +graphon/nodes/trigger_webhook/node.py +graphon/nodes/variable_aggregator/variable_aggregator_node.py +graphon/nodes/variable_assigner/v1/node.py +graphon/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8..3595ea33f0 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31..ffc17e92cf 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index be28b7e613..03ce574dca 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 77e40fc6fc..44735eb769 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -14,7 +14,7 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e4..5bb0c74ada 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -33,17 +33,17 @@ from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 508db22eb0..67f8795d3f 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -18,9 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from dify_graph.nodes.human_input.entities import FormDefinition -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6d..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 68cb3438ca..643a2a2a84 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -27,15 +27,15 @@ from core.trigger.constants import ( ) from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..9413a93fc4 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,10 +12,10 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] default_model_config["model"] = json.dumps(default_model_dict) @@ -197,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, @@ -241,7 +250,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +266,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) @@ -266,6 +281,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_name(self, app: App, name: str) -> App: @@ -281,6 +298,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: @@ -298,6 +317,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_site_status(self, app: App, enable_site: bool) -> App: @@ -315,6 +336,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_api_status(self, app: App, enable_api: bool) -> App: @@ -333,6 +356,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def delete_app(self, app: App): @@ -340,6 +365,8 @@ class AppService: Delete app :param app: App instance """ + app_was_deleted.send(app) + db.session.delete(app) db.session.commit() diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index d556230044..6e9d6b1c73 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -7,8 +7,8 @@ new GraphEngine command channel mechanism. from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947..9e743bf7b1 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -9,8 +9,8 @@ from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 6fe61a1a52..755407d849 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -389,7 +389,11 @@ class BillingService: # Redis returns bytes, decode to string and parse JSON json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value plan_dict = json.loads(json_str) + # NOTE (hj24): New billing versions may return timestamp as str, and validate_python + # in non-strict mode will coerce it to the expected int type. + # To preserve compatibility, always keep non-strict mode here and avoid strict mode. subscription_plan = subscription_adapter.validate_python(plan_dict) + # NOTE END tenant_plans[tenant_id] = subscription_plan except Exception: logger.exception( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 0e0eab00ad..c6b32b373e 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -10,10 +10,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 566c27c0f3..545c5048d5 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -10,9 +10,9 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index f00e3fe01e..287d513f48 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..3e2342b1a7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,16 +21,16 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -227,8 +228,8 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": - model_manager = ModelManager() + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -253,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -348,9 +352,9 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -367,7 +371,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -384,7 +388,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -405,7 +409,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -716,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -742,7 +746,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( @@ -860,7 +864,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -952,9 +956,9 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -975,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -990,13 +994,13 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1017,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1028,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1049,7 +1053,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1088,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1439,7 +1443,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -1906,9 +1910,9 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2039,7 +2043,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2220,7 +2224,7 @@ class DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2639,7 +2643,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -2688,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2711,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3100,7 +3104,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3124,8 +3128,8 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3157,7 +3161,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3207,8 +3211,8 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3229,9 +3233,9 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3254,7 +3258,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3321,7 +3325,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3344,9 +3348,9 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3381,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3408,8 +3412,8 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3418,7 +3422,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3435,7 +3439,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3448,9 +3452,9 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3480,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( @@ -3786,7 +3790,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3849,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index f3b2adb965..2b7bebb01e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -14,9 +14,9 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter -from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 9dd595f516..6679c08ebd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -15,9 +15,9 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, ProviderCredentialSchema, diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 4cf42b7f44..d2fa98f5e2 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -9,8 +9,8 @@ from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db +from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, diff --git a/api/services/file_service.py b/api/services/file_service.py index a7060f3b92..c11f018f52 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,10 +20,10 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9993d24c70..d490ad1561 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,8 +9,8 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db +from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da..68ef67dec1 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,16 +8,16 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) @@ -220,7 +220,7 @@ class EmailDeliveryTestHandler: stmt = stmt.where(Account.id.in_(unique_ids)) with self._session_factory() as session: - rows = session.execute(stmt).all() + rows = session.execute(stmt).tuples().all() return dict(rows) @staticmethod diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 2e74c50963..76598d31ac 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -11,12 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( +from graphon.nodes.human_input.entities import ( FormDefinition, HumanInputSubmissionValidationError, validate_human_input_submission, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc..3d6fdb08a3 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,12 +1,20 @@ import boto3 +from pydantic import BaseModel, Field from configs import dify_config +class BedrockRetrievalSetting(BaseModel): + """Retrieval settings for Amazon Bedrock knowledge base queries.""" + + top_k: int | None = Field(default=None, description="Maximum number of results to retrieve") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + + class ExternalDatasetTestService: # this service is only for internal testing @staticmethod - def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str): # get bedrock client client = boto3.client( "bedrock-agent-runtime", @@ -20,7 +28,7 @@ class ExternalDatasetTestService: knowledgeBaseId=knowledge_id, retrievalConfiguration={ "vectorSearchConfiguration": { - "numberOfResults": retrieval_setting.get("top_k"), + "numberOfResults": retrieval_setting.top_k, "overrideSearchType": "HYBRID", } }, @@ -33,7 +41,7 @@ class ExternalDatasetTestService: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + if retrieval_result.get("score") < retrieval_setting.score_threshold: continue result = { "metadata": retrieval_result.get("metadata"), diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f51..0c4a334b47 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -12,8 +12,8 @@ from core.model_manager import ModelManager from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating @@ -255,7 +255,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3ed..469357d6e0 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,14 +10,15 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions.ext_database import db +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential @@ -26,8 +27,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +42,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +63,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +85,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -222,8 +224,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -310,7 +312,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -495,8 +497,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -532,6 +534,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -542,6 +545,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -552,6 +556,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -581,7 +586,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1a..e634f90603 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +25,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +44,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +61,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +139,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +147,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -391,7 +412,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +497,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +530,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +546,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 296b9f0890..46a6221fcc 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -34,26 +34,33 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.workflow_node_execution import ( +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.graph_events.base import GraphNodeEventBase -from dify_graph.node_events.base import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase -from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.graph_events.base import GraphNodeEventBase +from graphon.node_events.base import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore @@ -88,6 +95,12 @@ from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -521,13 +534,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -571,6 +578,13 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + if workflow_node_execution_db_model is not None: + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=account.id, + ) return workflow_node_execution_db_model def run_datasource_workflow_node( @@ -959,10 +973,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1276,7 +1290,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1287,12 +1301,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, @@ -1334,6 +1347,12 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=current_user.id, + ) return workflow_node_execution_db_model def get_recommended_plugins(self, type: str) -> dict: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..1b8207cc31 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,17 +22,18 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd..215a8c8528 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,38 +102,38 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml @@ -169,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 00a2144800..c91f621ffb 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -31,9 +31,9 @@ from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..4334412c8b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,10 +12,11 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -191,7 +192,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -200,7 +201,8 @@ class SummaryIndexService: ) if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) - embedding_tokens = tokens_list[0] if tokens_list else 0 + raw_embedding_tokens = tokens_list[0] if tokens_list else 0 + embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) @@ -724,7 +726,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +853,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +891,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +983,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1014,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..70bf7f16f2 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding @@ -83,7 +84,7 @@ class TagService: raise ValueError("Tag name already exists") tag = Tag( name=args["name"], - type=args["type"], + type=TagType(args["type"]), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 408b1c22d1..9190a67249 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -20,8 +20,8 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 101b2fe5a2..931ca5021a 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -12,8 +12,8 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 7e9d010d2f..a827222c1d 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -13,7 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from dify_graph.entities.graph_config import NodeConfigDict +from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 24bbeda329..dca00a466b 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -18,9 +18,9 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 3bc64423b3..774068895b 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -15,6 +15,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( @@ -23,13 +24,13 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, WebhookParameter, ) -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.file.models import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger @@ -46,6 +47,7 @@ except ImportError: magic = None # type: ignore[assignment] logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WebhookService: @@ -422,6 +424,7 @@ class WebhookService: return file_factory.build_from_mapping( mapping=mapping, tenant_id=webhook_trigger.tenant_id, + access_controller=_file_access_controller, ) @classmethod diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 60dc1dedb8..d0a4317065 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -6,9 +6,9 @@ from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload from configs import dify_config -from dify_graph.file.models import File -from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable -from dify_graph.variables.segments import ( +from graphon.file.models import File +from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable +from graphon.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -20,7 +20,7 @@ from dify_graph.variables.segments import ( Segment, StringSegment, ) -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..5fd310b689 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,12 +4,12 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument @@ -45,9 +45,9 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index f0596e44c8..1f3993505c 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -17,13 +17,13 @@ from core.app.apps.completion.app_config_manager import CompletionAppConfigManag from core.helper import encrypter from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.models import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file.models import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 9489618762..fa26f507ee 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f124e137c3..0b5c89e574 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,28 +14,36 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey -from dify_graph.file.models import File -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables import Segment, StringSegment, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import ( - ArrayFileSegment, - FileSegment, +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from graphon.enums import NodeType +from graphon.file.models import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation from models.enums import ConversationFromSource, DraftVariableType +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -71,7 +79,7 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: dify_graph.variable_loader.VariableLoader + # ref: graphon.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine @@ -120,7 +128,11 @@ class DraftVarLoader(VariableLoader): elif isinstance(value, ArrayFileSegment): files.extend(value.value) with Session(bind=self._engine) as session: - storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader = StorageKeyLoader( + session, + tenant_id=self._tenant_id, + access_controller=DatabaseFileAccessController(), + ) storage_key_loader.load_storage_keys(files) offloaded_draft_vars = [] @@ -174,7 +186,7 @@ class DraftVarLoader(VariableLoader): return (draft_var.node_id, draft_var.name), variable deserialized = json.loads(content) - segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized) + segment = draft_var.build_segment_from_serialized_value(variable_file.value_type, deserialized) variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), @@ -838,6 +850,12 @@ class DraftVariableSaver: self._user = user self._enclosing_node_id = enclosing_node_id + def _resolve_app_tenant_id(self) -> str: + tenant_id = self._session.scalar(select(App.tenant_id).where(App.id == self._app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {self._app_id}") + return tenant_id + def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -892,27 +910,18 @@ class DraftVariableSaver: for name, value in output.items(): value_seg = _build_segment_for_serialized_values(value) node_id, name = self._normalize_variable_for_start_node(name) - # If node_id is not `sys`, it means that the variable is a user-defined input field - # in `Start` node. - if node_id != SYSTEM_VARIABLE_NODE_ID: - draft_vars.append( - WorkflowDraftVariable.new_node_variable( - app_id=self._app_id, - user_id=self._user.id, - node_id=self._node_id, - name=name, - node_execution_id=self._node_execution_id, - value=value_seg, - visible=True, - editable=True, - ) - ) - has_non_sys_variables = True - else: + if node_id == SYSTEM_VARIABLE_NODE_ID: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we - # just build files from the value. - files = [File.model_validate(v) for v in value] + # just rebuild files from the serialized payload. + tenant_id = self._resolve_app_tenant_id() + files = [ + build_file_from_stored_mapping( + file_mapping=v, + tenant_id=tenant_id, + ) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: @@ -928,15 +937,47 @@ class DraftVariableSaver: editable=self._should_variable_be_editable(node_id, name), ) ) + elif node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + user_id=self._user.id, + name=name, + value=value_seg, + ) + ) + has_non_sys_variables = True + else: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + user_id=self._user.id, + node_id=node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(node_id, self._node_type, name), + editable=self._should_variable_be_editable(node_id, name), + ) + ) + has_non_sys_variables = True if not has_non_sys_variables: draft_vars.append(self._create_dummy_output_variable()) return draft_vars def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: - if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return self._node_id, name - _, name_ = name.split(".", maxsplit=1) - return SYSTEM_VARIABLE_NODE_ID, name_ + for reserved_node_id in ( + SYSTEM_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + CONVERSATION_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + ): + prefix = f"{reserved_node_id}." + if name.startswith(prefix): + _, name_ = name.split(".", maxsplit=1) + return reserved_node_id, name_ + + return self._node_id, name def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars = [] diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 8f323ebb8b..5fca444723 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -22,10 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.entities import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..bef99458be 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -12,48 +12,52 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams, WorkflowNodeExecution -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file import File -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import load_into_variable_pool -from dify_graph.variables import VariableBase -from dify_graph.variables.input_entities import VariableEntityType -from dify_graph.variables.variables import Variable +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -82,6 +86,8 @@ from .human_input_delivery_test_service import ( from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService from .workflow_restore import apply_published_workflow_snapshot_to_draft +_file_access_controller = DatabaseFileAccessController() + class WorkflowService: """ @@ -486,13 +492,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or credentials fail policy checks """ try: - from core.model_manager import ModelManager - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType + + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) # Get model instance to validate provider+model combination - model_manager = ModelManager() - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -501,8 +509,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None @@ -607,11 +614,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -765,6 +771,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. conversation_variables=[], node_type=node_type, @@ -772,11 +779,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -841,6 +850,13 @@ class WorkflowService: draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=None, + user_id=account.id, + ) + return workflow_node_execution def get_human_input_form_preview( @@ -895,7 +911,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") @@ -995,17 +1010,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1055,7 +1073,7 @@ class WorkflowService: node_data: HumanInputNodeData, delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1070,9 +1088,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1138,7 +1155,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1155,11 +1172,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1419,10 +1438,10 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from dify_graph.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1511,38 +1530,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1567,7 +1596,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): raise ValueError(f"expected dict for file object, got {type(value)}") - return build_from_mapping(mapping=value, tenant_id=tenant_id) + return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): raise ValueError(f"expected list for file list object, got {type(value)}") @@ -1575,6 +1604,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia return [] if not isinstance(value[0], dict): raise ValueError(f"expected dict for first element in the file list, got {type(value)}") - return build_from_mappings(mappings=value, tenant_id=tenant_id) + return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 174aa50343..458099d99e 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -21,8 +21,8 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -239,13 +239,18 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode + response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], + workflow_run_id: str, + app_mode: AppMode, ) -> None: topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) for event in response_stream: try: - payload = json.dumps(event) - except TypeError: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): logger.exception("error while encoding event") continue diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index d247cf5cf7..6365400dd1 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -21,8 +21,8 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..ed8a24b336 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,9 +11,10 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -119,8 +120,8 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": - model_manager = ModelManager() + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index dd3b6a4530..fd743205a1 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -7,10 +7,10 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d241783359..f8ae3f4b6e 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,10 +11,10 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from dify_graph.runtime import GraphRuntimeState, VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..c95b8db078 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,17 +39,36 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: - logger.info("error:\n\n\n%s\n\n\n\n", e) + logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -52,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 3a9a663759..07aeeb1e35 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -27,8 +27,8 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index f41118e592..ae1c2991c9 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -12,8 +12,8 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.entities.workflow_execution import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 466ef6c858..b823ce3961 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -12,10 +12,10 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -125,7 +125,7 @@ def _create_node_execution_from_domain( else: node_execution.execution_metadata = "{}" - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.created_by_role = creator_user_role @@ -159,7 +159,7 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode node_execution.execution_metadata = "{}" # Update other fields - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 4fdbb7d9f3..a876b0c4aa 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -2,7 +2,7 @@ from collections.abc import Generator from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from dify_graph.node_events import StreamCompletedEvent +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5b..b2de11b068 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,7 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca1..878d9b24df 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -6,10 +6,11 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -192,19 +197,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -313,7 +315,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -337,7 +339,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -364,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2..a942690cbd 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 4e184c93fd..c4146d5ccd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -8,23 +8,23 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.model import PluginModelClient # import monkeypatch -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a869691..0b21ff1d2a 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,15 +6,15 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index bc83c6cc12..f6f4cf260b 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -5,8 +5,8 @@ import pytest from sqlalchemy import delete from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -192,7 +192,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -423,7 +423,7 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed1..a9a2617bae 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType @@ -15,7 +15,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b866..7573e00872 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,13 +6,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e55..17ea7de881 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,12 +9,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.graph import Graph -from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -54,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -189,20 +191,20 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from dify_graph.enums import BuiltinNodeTypes - from dify_graph.nodes.http_request.entities import ( + from core.workflow.system_variables import build_system_variables + from graphon.enums import BuiltinNodeTypes + from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from dify_graph.nodes.http_request.exc import AuthorizationConfigError - from dify_graph.nodes.http_request.executor import Executor - from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable + from graphon.nodes.http_request.exc import AuthorizationConfigError + from graphon.nodes.http_request.executor import Executor + from graphon.runtime import VariablePool # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -700,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -728,6 +730,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1e..fa5d63cfbf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -7,14 +7,16 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import LLMNode +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -51,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, @@ -66,6 +68,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +82,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -115,8 +123,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -159,8 +167,8 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + def mock_fetch_prompt_messages_1(**_kwargs): + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -231,8 +239,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -276,7 +284,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af0196..367b5bbc11 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -56,7 +57,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), user_inputs={}, @@ -77,6 +78,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c3..9e3e1a47e3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -3,12 +3,12 @@ import uuid from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +66,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -90,7 +90,7 @@ def test_execute_template_transform(): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada31..f9ec51ee10 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.nodes.tool.tool_node import ToolNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -64,11 +65,12 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=DifyToolNodeRuntime(init_params.run_context), ) return node -def test_tool_variable_invoke(): +def test_tool_variable_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", @@ -103,7 +105,7 @@ def test_tool_variable_invoke(): assert item.node_run_result.outputs.get("text") is not None -def test_tool_mixed_invoke(): +def test_tool_mixed_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef0ca4232d..48bf3ca446 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -33,6 +33,9 @@ from extensions.ext_database import db logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +DEFAULT_SANDBOX_TEST_IMAGE = "langgenius/dify-sandbox:0.2.14" +SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" + class _CloserProtocol(Protocol): """_Closer is any type which implement the close() method.""" @@ -163,11 +166,11 @@ class DifyTestContainers: wait_for_logs(self.redis, "Ready to accept connections", timeout=30) logger.info("Redis container is ready and accepting connections") - # Start Dify Sandbox container for code execution environment - # Dify Sandbox provides a secure environment for executing user code - # Use pinned version 0.2.12 to match production docker-compose configuration + # Start Dify Sandbox container for code execution environment. + # Default to the production-pinned image while allowing local overrides for debugging. logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network) + sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) + self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -177,7 +180,12 @@ class DifyTestContainers: sandbox_port = self.dify_sandbox.get_exposed_port(8194) os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" - logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + logger.info( + "Dify Sandbox container started successfully - Image: %s Host: %s, Port: %s", + sandbox_image, + sandbox_host, + sandbox_port, + ) # Wait for Dify Sandbox to be ready logger.info("Waiting for Dify Sandbox to be ready to accept connections...") @@ -187,7 +195,7 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network( + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network( self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 4f606dccb8..5b51510388 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..6b51ec98bc --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -0,0 +1,342 @@ +"""Authenticated controller integration tests for console message APIs.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload +from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation: + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Test Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + account_id: str, + *, + created_at_offset_seconds: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=1, + message_unit_price=Decimal("0.0001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=1, + answer_unit_price=Decimal("0.0001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal("0.0002"), + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +class TestMessageValidators: + def test_chat_messages_query_validators(self) -> None: + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self) -> None: + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self) -> None: + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +def test_chat_message_list_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": str(uuid4())}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_chat_message_list_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0) + second = _create_message( + db_session_with_containers, + app.id, + conversation.id, + account.id, + created_at_offset_seconds=1, + ) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": conversation.id, "limit": 1}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["limit"] == 1 + assert payload["has_more"] is True + assert len(payload["data"]) == 1 + assert payload["data"][0]["id"] == second.id + + +def test_message_feedback_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": str(uuid4()), "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_message_feedback_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": message.id, "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + feedback = db_session_with_containers.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == message.id) + ) + assert feedback is not None + assert feedback.rating == FeedbackRating.LIKE + assert feedback.from_account_id == account.id + + +def test_message_annotation_count( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + db_session_with_containers.add( + MessageAnnotation( + app_id=app.id, + conversation_id=conversation.id, + message_id=message.id, + question="Q", + content="A", + account_id=account.id, + ) + ) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/annotations/count", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"count": 1} + + +def test_message_suggested_questions_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"data": ["q1", "q2"]} + + +@pytest.mark.parametrize( + ("exc", "expected_status", "expected_code"), + [ + (MessageNotExistsError(), 404, "not_found"), + (ConversationNotExistsError(), 404, "not_found"), + (ProviderTokenNotInitError(), 400, "provider_not_initialize"), + (QuotaExceededError(), 400, "provider_quota_exceeded"), + (ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"), + (SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"), + (Exception(), 500, "internal_server_error"), + ], +) +def test_message_suggested_questions_errors( + exc: Exception, + expected_status: int, + expected_code: str, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + side_effect=exc, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == expected_status + payload = response.get_json() + assert payload is not None + assert payload["code"] == expected_code + + +def test_message_feedback_export_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/feedbacks/export", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"exported": True} + + +def test_message_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/messages/{message.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == message.id diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..963cfe53e5 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02") + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..290be87697 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,415 @@ +"""Authenticated controller integration tests for workflow draft variable APIs.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from factories.variable_factory import segment_to_variable +from graphon.variables.segments import StringSegment +from models import Workflow +from models.model import AppMode +from models.workflow import WorkflowDraftVariable +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_draft_workflow( + db_session: Session, + app_id: str, + tenant_id: str, + account_id: str, + *, + environment_variables: list | None = None, + conversation_variables: list | None = None, +) -> Workflow: + workflow = Workflow.new( + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=account_id, + environment_variables=environment_variables or [], + conversation_variables=conversation_variables or [], + rag_pipeline_variables=[], + ) + db_session.add(workflow) + db_session.commit() + return workflow + + +def _create_node_variable( + db_session: Session, + app_id: str, + user_id: str, + *, + node_id: str = "node_1", + name: str = "test_var", +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + user_id=user_id, + node_id=node_id, + name=name, + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + visible=True, + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _create_system_variable( + db_session: Session, app_id: str, user_id: str, name: str = "query" +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + user_id=user_id, + name=name, + value=StringSegment(value="system-value"), + node_execution_id=str(uuid.uuid4()), + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _build_environment_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, name], + name=name, + description=f"Environment variable {name}", + ) + + +def _build_conversation_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[CONVERSATION_VARIABLE_NODE_ID, name], + name=name, + description=f"Conversation variable {name}", + ) + + +def test_workflow_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"items": [], "total": 0} + + +def test_workflow_variable_collection_get_not_exist( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "draft_workflow_not_exist" + + +def test_workflow_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_node_variable(db_session_with_containers, app.id, account.id) + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var") + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + remaining = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + ) + ).all() + assert remaining == [] + + +def test_node_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [node_variable.id] + + +def test_node_variable_collection_get_invalid_node_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "invalid_param" + + +def test_node_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456") + target_id = target.id + untouched_id = untouched.id + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id)) + is None + ) + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id)) + is not None + ) + + +def test_variable_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "test_var" + + +def test_variable_api_get_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_variable_api_patch_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.patch( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + json={"name": "renamed_var"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "renamed_var" + + refreshed = db_session_with_containers.scalar( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id) + ) + assert refreshed is not None + assert refreshed.name == "renamed_var" + + +def test_variable_api_delete_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_variable_reset_api_put_success_returns_no_content_without_execution( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.put( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_conversation_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + conversation_variables=[_build_conversation_variable("session_name", "Alice")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/conversation-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["name"] for item in payload["items"]] == ["session_name"] + + created = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID, + ) + ).all() + assert len(created) == 1 + + +def test_system_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + variable = _create_system_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/system-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [variable.id] + + +def test_environment_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + environment_variables=[_build_environment_variable("api_key", "secret-value")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/environment-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["items"][0]["name"] == "api_key" + assert payload["items"][0]["value"] == "secret-value" diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..00309c25d6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,131 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) + is None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..81b5423261 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,120 @@ +"""Controller integration tests for console OAuth data source routes.""" + +from unittest.mock import MagicMock, patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.source import DataSourceOauthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_oauth_url_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + provider = MagicMock() + provider.get_authorization_url.return_value = "http://oauth.provider/auth" + + with ( + patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}), + patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None), + ): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/notion", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert tenant.id == account.current_tenant_id + assert response.status_code == 200 + assert response.get_json() == {"data": "http://oauth.provider/auth"} + provider.get_authorization_url.assert_called_once() + + +def test_get_oauth_url_invalid_provider( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/unknown_provider", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + +def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion") + + assert response.status_code == 302 + assert "error=Access%20denied" in response.location + + +def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None: + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123") + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.get_access_token.assert_called_once_with("auth_code_123") + + +def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid code"} + + +def test_sync_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceOauthBinding( + tenant_id=tenant.id, + access_token="test-access-token", + provider="notion", + source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}, + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get( + f"/console/api/oauth/data-source/notion/{binding.id}/sync", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.sync_data_source.assert_called_once_with(binding.id) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_email_register.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 724c80f18c..879c337319 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for email register controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestEmailRegisterSendEmailApi: - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - mock_session_cls, app, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False mock_account = MagicMock() - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), @@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi: feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -114,7 +107,6 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @@ -125,7 +117,6 @@ class TestEmailRegisterResetApi: mock_get_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_create_account, mock_login, mock_reset_login_rate, @@ -136,14 +127,10 @@ class TestEmailRegisterResetApi: token_pair = MagicMock() token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -159,19 +146,19 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_forgot_password.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 8403777dc9..7b7393dade 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for forgot password controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestForgotPasswordSendEmailApi: - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - mock_session_cls, app, ): mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) controller_features = SimpleNamespace(is_allow_register=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( "controllers.console.auth.forgot_password.FeatureService.get_system_features", return_value=controller_features, @@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @@ -126,7 +120,6 @@ class TestForgotPasswordResetApi: mock_get_reset_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_update_account, app, ): @@ -134,12 +127,8 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - wraps_features = SimpleNamespace(enable_email_password_login=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): @@ -157,20 +146,22 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" + from unittest.mock import MagicMock + mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py similarity index 92% rename from api/tests/unit_tests/controllers/console/auth/test_oauth.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 6345c2ab23..a2f1328579 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for OAuth controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.mark.parametrize( ("github_config", "google_config", "expected_github", "expected_google"), @@ -64,10 +65,8 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_oauth_provider(self): @@ -131,10 +130,8 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def oauth_setup(self): @@ -190,15 +187,8 @@ class TestOAuthCallback: (KeyError("Missing key"), "OAuth process failed"), ], ) - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions( - self, mock_get_providers, mock_db, resource, app, exception, expected_error - ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - + def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): # Import the real requests module to create a proper exception import httpx @@ -258,7 +248,6 @@ class TestOAuthCallback: ) @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") @@ -269,7 +258,6 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - mock_db, mock_tenant_service, mock_account_service, resource, @@ -278,10 +266,6 @@ class TestOAuthCallback: account_status, expected_redirect, ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - mock_db.session.commit = MagicMock() mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -306,14 +290,12 @@ class TestOAuthCallback: @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") def test_should_activate_pending_account( self, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -338,12 +320,10 @@ class TestOAuthCallback: assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None - mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.redirect") @@ -352,7 +332,6 @@ class TestOAuthCallback: mock_redirect, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -414,6 +393,10 @@ class TestOAuthCallback: class TestAccountGeneration: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def user_info(self): return OAuthUserInfo(id="123", name="Test User", email="test@example.com") @@ -425,15 +408,10 @@ class TestAccountGeneration: return account @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Account") - @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account + self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account ): - # Mock db.engine for Session creation - mock_db.engine = MagicMock() - # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) @@ -443,15 +421,14 @@ class TestAccountGeneration: # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None - mock_session_instance = MagicMock() - mock_session.return_value.__enter__.return_value = mock_session_instance mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account - mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + mock_get_account.assert_called_once() - def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -462,7 +439,7 @@ class TestAccountGeneration: result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert result == expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( @@ -478,10 +455,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_handle_account_generation_scenarios( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, @@ -519,10 +494,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_register_with_lowercase_email( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..2ef27133d8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,365 @@ +"""Controller integration tests for console OAuth server routes.""" + +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _build_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert "redirect_uri is invalid" in payload["message"] + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert "client_id is invalid" in payload["message"] + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account.avatar = "avatar_url" + db_session_with_containers.commit() + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": account.email, + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py new file mode 100644 index 0000000000..9e2084f393 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -0,0 +1,85 @@ +"""Shared helpers for authenticated console controller integration tests.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.model import App, AppMode +from services.account_service import AccountService + + +def ensure_dify_setup(db_session: Session) -> None: + """Create a setup marker once so setup-protected console routes can be exercised.""" + if db_session.scalar(select(DifySetup).limit(1)) is not None: + return + + db_session.add(DifySetup(version=dify_config.project.version)) + db_session.commit() + + +def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + """Create an initialized owner account with a current tenant.""" + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + db_session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + ) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + ensure_dify_setup(db_session) + return account, tenant + + +def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App: + """Create a minimal app row that can be loaded by get_app_model.""" + app = App( + tenant_id=tenant_id, + name="Test App", + mode=mode, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]: + """Attach console auth cookies/headers for endpoints guarded by login_required.""" + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost") + return { + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + } diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea293..b8840c4ba8 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,17 +31,18 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events.graph import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper +from graphon.runtime.variable_pool import VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -212,7 +213,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) @@ -544,7 +545,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from dify_graph.graph_events.graph import ( + from graphon.graph_events.graph import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12..e0c58f0f5c 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,20 +7,17 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df3..ae8c0716a4 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -10,22 +10,23 @@ from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -39,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -52,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb0..2e207ddc67 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -6,10 +6,11 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -193,19 +198,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -314,7 +316,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -338,7 +340,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -365,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index fb8d1808f9..2fd289dfbc 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -1,11 +1,12 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 -from dify_graph.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent @@ -117,7 +118,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: inputs=[], user_actions=[UserAction(id=action_id, title=action_text)], rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), node_title=node_title, display_in_ui=True, ) @@ -129,7 +130,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: form_definition=form_definition.model_dump_json(), rendered_content="Rendered block", status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), selected_action_id=action_id, ) db_session.add(form) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 0000000000..a79208f649 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. +""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d1..0000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 458862b0ec..641399c7f9 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ from uuid import uuid4 from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index c3ed79656f..cb00752b35 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,7 +2,6 @@ from __future__ import annotations -import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -12,22 +11,20 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( - BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, - RecipientType, ) -from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, @@ -218,7 +215,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -278,7 +275,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -636,12 +633,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_prefers_backstage_token_when_available( + def test_builds_reason_from_form_definition( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Use backstage token when multiple recipient types may exist.""" + """Build the graph pause reason from the stored form definition.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -668,25 +665,6 @@ class TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() - delivery = HumanInputDelivery( - form_id=form_model.id, - delivery_method_type=DeliveryMethodType.WEBAPP, - channel_payload="{}", - ) - db_session_with_containers.add(delivery) - db_session_with_containers.flush() - - access_token = secrets.token_urlsafe(8) - recipient = HumanInputFormRecipient( - form_id=form_model.id, - delivery_id=delivery.id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - db_session_with_containers.add(recipient) - db_session_with_containers.flush() - # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -716,13 +694,12 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) - db_session_with_containers.refresh(recipient) - reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + reason = _build_human_input_required_reason(reason_model, form_model) assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token assert reason.node_title == "Ask Name" assert reason.form_content == "content" assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"} diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..aaf9a85d60 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,408 @@ +"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers. + +Part of #32454 — replaces the mock-based unit tests with real database interactions. +""" + +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from datetime import timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from libs.datetime_utils import naive_utc_now +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, InvokeFrom +from models.execution_extra_content import ExecutionExtraContent, HumanInputContent +from models.human_input import ( + ConsoleRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from models.model import App, Conversation, Message +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows. + + IDs are populated after flushing the base entities to the database. + """ + + tenant_id: str = "" + app_id: str = "" + user_id: str = "" + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(ExecutionExtraContent).where( + ExecutionExtraContent.workflow_run_id.in_( + select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id) + ) + ) + ) + session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id)) + session.execute(delete(Message).where(Message.app_id == scope.app_id)) + session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id)) + session.execute(delete(App).where(App.id == scope.app_id)) + session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id)) + session.execute(delete(Account).where(Account.id == scope.user_id)) + session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id)) + session.commit() + + +def _seed_base_entities(session: Session, scope: _TestScope) -> None: + """Create the base tenant, account, and app needed by tests.""" + tenant = Tenant(name="Test Tenant") + session.add(tenant) + session.flush() + scope.tenant_id = tenant.id + + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + scope.user_id = account.id + + tenant_join = TenantAccountJoin( + tenant_id=scope.tenant_id, + account_id=scope.user_id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(tenant_join) + + app = App( + tenant_id=scope.tenant_id, + name="Test App", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=scope.user_id, + updated_by=scope.user_id, + ) + session.add(app) + session.flush() + scope.app_id = app.id + + +def _create_conversation(session: Session, scope: _TestScope) -> Conversation: + conversation = Conversation( + app_id=scope.app_id, + mode="chat", + name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + from_end_user_id=None, + ) + conversation.inputs = {} + session.add(conversation) + session.flush() + return conversation + + +def _create_message( + session: Session, + scope: _TestScope, + conversation_id: str, + workflow_run_id: str, +) -> Message: + message = Message( + app_id=scope.app_id, + conversation_id=conversation_id, + inputs={}, + query="test query", + message={"messages": []}, + answer="test answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + workflow_run_id=workflow_run_id, + ) + session.add(message) + session.flush() + return message + + +def _create_submitted_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + action_id: str = "approve", + action_title: str = "Approve", + node_title: str = "Approval", +) -> HumanInputForm: + expiration_time = naive_utc_now() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content=f"Rendered {action_title}", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + selected_action_id=action_id, + ) + session.add(form) + session.flush() + return form + + +def _create_waiting_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + default_values: dict | None = None, +) -> HumanInputForm: + expiration_time = naive_utc_now() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values=default_values or {"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + session.add(form) + session.flush() + return form + + +def _create_human_input_content( + session: Session, + *, + workflow_run_id: str, + message_id: str, + form_id: str, +) -> HumanInputContent: + content = HumanInputContent.new( + workflow_run_id=workflow_run_id, + message_id=message_id, + form_id=form_id, + ) + session.add(content) + return content + + +def _create_recipient( + session: Session, + *, + form_id: str, + delivery_id: str, + recipient_type: RecipientType = RecipientType.CONSOLE, + access_token: str = "token-1", +) -> HumanInputFormRecipient: + payload = ConsoleRecipientPayload(account_id=None) + recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=recipient_type, + recipient_payload=payload.model_dump_json(), + access_token=access_token, + ) + session.add(recipient) + return recipient + + +def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: + from core.workflow.human_input_compat import DeliveryMethodType + from models.human_input import ConsoleDeliveryPayload + + delivery = HumanInputDelivery( + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + session.add(delivery) + session.flush() + return delivery + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + _seed_base_entities(db_session_with_containers, scope) + db_session_with_containers.commit() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetByMessageIds: + """Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids.""" + + def test_groups_contents_by_message( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Submitted forms are correctly mapped and grouped by message ID.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_submitted_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + action_id="approve", + action_title="Approve", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg1.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg1.id, msg2.id]) + + assert len(result) == 2 + # msg1 has one submitted content + assert len(result[0]) == 1 + content = result[0][0] + assert content.submitted is True + assert content.workflow_run_id == workflow_run_id + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == "approve" + assert content.form_submission_data.action_text == "Approve" + assert content.form_submission_data.rendered_content == "Rendered Approve" + assert content.form_submission_data.node_id == "node-id" + assert content.form_submission_data.node_title == "Approval" + # msg2 has no content + assert result[1] == [] + + def test_returns_unsubmitted_form_definition( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Waiting forms return full form_definition with resolved token and defaults.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_waiting_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + default_values={"name": "John"}, + ) + delivery = _create_delivery(db_session_with_containers, form_id=form.id) + _create_recipient( + db_session_with_containers, + form_id=form.id, + delivery_id=delivery.id, + access_token="token-1", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg.id]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == workflow_run_id + assert domain_content.form_definition is not None + form_def = domain_content.form_definition + assert form_def.form_id == form.id + assert form_def.node_id == "node-id" + assert form_def.node_title == "Approval" + assert form_def.form_content == "Rendered block" + assert form_def.display_in_ui is True + assert form_def.form_token == "token-1" + assert form_def.resolved_default_values == {"name": "John"} + assert form_def.expiration_time == int(form.expiration_time.timestamp()) + + def test_empty_message_ids_returns_empty_list( + self, + repository: SQLAlchemyExecutionExtraContentRepository, + ) -> None: + """Passing no message IDs returns an empty list without hitting the DB.""" + result = repository.get_by_message_ids([]) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index 1568d5d65c..d6f0657380 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -11,8 +11,8 @@ from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/auth/__init__.py b/api/tests/test_containers_integration_tests/services/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py new file mode 100644 index 0000000000..177fb95ff3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class TestApiKeyAuthService: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def provider(self) -> str: + return "google" + + @pytest.fixture + def mock_credentials(self) -> dict: + return {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}} + + @pytest.fixture + def mock_args(self, category, provider, mock_credentials) -> dict: + return {"category": category, "provider": provider, "credentials": mock_credentials} + + def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=json.dumps(credentials, ensure_ascii=False) if credentials else None, + disabled=disabled, + ) + db_session.add(binding) + db_session.commit() + return binding + + def test_get_provider_auth_list_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + assert len(result) >= 1 + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert len(tenant_results) == 1 + assert tenant_results[0].provider == provider + + def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + def test_get_provider_auth_list_filters_disabled( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_success( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + mock_factory.assert_called_once() + mock_auth_instance.validate_credentials.assert_called_once() + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, "test_secret_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 1 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_validation_failed( + self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = False + mock_factory.return_value = mock_auth_instance + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 0 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encrypts_api_key( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + original_key = mock_args["credentials"]["config"]["api_key"] + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" + assert mock_args["credentials"]["config"]["api_key"] != original_key + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) + + def test_get_auth_credentials_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + ): + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=mock_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == mock_credentials + + def test_get_auth_credentials_not_found( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result is None + + def test_get_auth_credentials_json_parsing( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=special_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == special_credentials + assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" + + def test_delete_provider_auth_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + binding = self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider + ) + binding_id = binding.id + db_session_with_containers.expire_all() + + ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + + db_session_with_containers.expire_all() + remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() + assert remaining is None + + def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + # Should not raise when binding not found + ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + + def test_validate_api_key_auth_args_success(self, mock_args): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_category(self, mock_args): + del mock_args["category"] + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_category(self, mock_args): + mock_args["category"] = "" + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_provider(self, mock_args): + del mock_args["provider"] + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_provider(self, mock_args): + mock_args["provider"] = "" + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_credentials(self, mock_args): + del mock_args["credentials"] + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_credentials(self, mock_args): + mock_args["credentials"] = None + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_invalid_credentials_type(self, mock_args): + mock_args["credentials"] = "not_a_dict" + with pytest.raises(ValueError, match="credentials must be a dictionary"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_auth_type(self, mock_args): + del mock_args["credentials"]["auth_type"] + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = "" + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @pytest.mark.parametrize( + "malicious_input", + [ + "", + "'; DROP TABLE users; --", + "../../../etc/passwd", + "\\x00\\x00", + "A" * 10000, + ], + ) + def test_validate_api_key_auth_args_malicious_input(self, malicious_input, mock_args): + mock_args["category"] = malicious_input + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_database_error_handling( + self, mock_encrypter, mock_factory, flask_app_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_key" + + with patch("services.auth.api_key_auth_service.db.session") as mock_session: + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args): + mock_factory.side_effect = Exception("Factory error") + with pytest.raises(Exception, match="Factory error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, tenant_id, mock_args): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") + with pytest.raises(Exception, match="Encryption error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + def test_validate_api_key_auth_args_none_input(self): + with pytest.raises(TypeError): + ApiKeyAuthService.validate_api_key_auth_args(None) + + def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = ["api_key"] + ApiKeyAuthService.validate_api_key_auth_args(mock_args) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py new file mode 100644 index 0000000000..dc4c0fda1d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -0,0 +1,264 @@ +""" +API Key Authentication System Integration Tests +""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock, patch +from uuid import uuid4 + +import httpx +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory +from services.auth.api_key_auth_service import ApiKeyAuthService +from services.auth.auth_type import AuthType + + +class TestAuthIntegration: + @pytest.fixture + def tenant_id_1(self) -> str: + return str(uuid4()) + + @pytest.fixture + def tenant_id_2(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def firecrawl_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} + + @pytest.fixture + def jina_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + def test_end_to_end_auth_flow( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_fc_test_key_123" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + mock_http.assert_called_once() + call_args = mock_http.call_args + assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] + assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" + + mock_encrypt.assert_called_once_with(tenant_id_1, "fc_test_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 1 + assert bindings[0].provider == AuthType.FIRECRAWL + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_cross_component_integration(self, mock_http, firecrawl_credentials): + mock_http.return_value = self._create_success_response() + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + result = factory.validate_credentials() + + assert result is True + mock_http.assert_called_once() + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.jina.jina.httpx.post") + def test_multi_tenant_isolation( + self, + mock_jina_http, + mock_fc_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + tenant_id_2, + category, + firecrawl_credentials, + jina_credentials, + ): + mock_fc_http.return_value = self._create_success_response() + mock_jina_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + + args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + + db_session_with_containers.expire_all() + + result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + + assert len(result1) == 1 + assert result1[0].tenant_id == tenant_id_1 + assert len(result2) == 1 + assert result2[0].tenant_id == tenant_id_2 + + def test_cross_tenant_access_prevention( + self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + + assert result is None + + def test_sensitive_data_protection(self): + credentials_with_secrets = { + "auth_type": "bearer", + "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, + } + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) + factory_str = str(factory) + + assert "super_secret_key_do_not_log" not in factory_str + assert "another_secret" not in factory_str + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token", return_value="encrypted_key") + def test_concurrent_creation_safety( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + app = flask_app_with_containers + mock_http.return_value = self._create_success_response() + + results = [] + exceptions = [] + + def create_auth(): + try: + with app.app_context(): + thread_args = { + "category": category, + "provider": AuthType.FIRECRAWL, + "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, + } + ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + results.append("success") + except Exception as e: + exceptions.append(e) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_auth) for _ in range(5)] + for future in futures: + future.result() + + assert len(results) == 5 + assert len(exceptions) == 0 + + @pytest.mark.parametrize( + "invalid_input", + [ + None, + {}, + {"auth_type": "bearer"}, + {"auth_type": "bearer", "config": {}}, + ], + ) + def test_invalid_input_boundary(self, invalid_input): + with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): + ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_http_error_handling(self, mock_http, firecrawl_credentials): + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = '{"error": "Unauthorized"}' + mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") + mock_http.return_value = mock_response + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + with pytest.raises((httpx.HTTPError, Exception)): + factory.validate_credentials() + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_network_failure_recovery( + self, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.side_effect = httpx.RequestError("Network timeout") + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + + with pytest.raises(httpx.RequestError): + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 0 + + @pytest.mark.parametrize( + ("provider", "credentials"), + [ + (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), + (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), + (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), + ], + ) + def test_all_providers_factory_creation(self, provider, credentials): + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None + + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_get_auth_credentials_returns_stored_credentials( + self, + mock_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + assert result is not None + assert result["config"]["api_key"] == "encrypted_key" + + def _create_success_response(self, status_code=200): + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"status": "success"} + mock_response.raise_for_status.return_value = None + return mock_response diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index f995ac7bef..42d587b7f7 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status diff --git a/api/tests/test_containers_integration_tests/services/enterprise/__init__.py b/api/tests/test_containers_integration_tests/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..4e8255d8ed --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,200 @@ +"""Integration tests for account deletion synchronization. + +Verifies enterprise account deletion sync functionality including +Redis queuing, error handling, and community vs enterprise behavior. +""" + +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from redis import RedisError + +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + def test_queue_task_success(self): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source="test_source") + + assert result is True + + import json + + raw = redis_client.rpop("enterprise:member:sync:queue") + assert raw is not None + task_data = json.loads(raw) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == "test_source" + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = RedisError("Connection failed") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = TypeError("Cannot serialize") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source="removed") + + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source="removed") + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is False + + +class TestSyncAccountDeletion: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is True + assert mock_queue_task.call_count == 3 + + queued_workspace_ids = {call.kwargs["workspace_id"] for call in mock_queue_task.call_args_list} + assert queued_workspace_ids == set(tenant_ids) + + def test_sync_account_deletion_no_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + fail_tenant = tenant_ids[1] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != fail_tenant + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_id = str(uuid4()) + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + mock_queue_task.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/plugin/__init__.py b/api/tests/test_containers_integration_tests/services/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index bfa9fe976b..3885137221 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -6,10 +6,13 @@ HIDDEN_VALUE replacement, and error handling for missing records. from __future__ import annotations +import json from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -39,67 +42,73 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + def test_fetches_credentials_with_credential_id( + self, + mock_tool_mgr, + mock_encrypter_fn, + mock_client_cls, + flask_app_with_containers, + db_session_with_containers, + ): + tenant_id = str(uuid4()) provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl encrypter = MagicMock() encrypter.decrypt.return_value = {"api_key": "decrypted"} mock_encrypter_fn.return_value = (encrypter, None) + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - # Mock the Session/query chain - db_record = MagicMock() - db_record.credentials = {"api_key": "encrypted"} - db_record.credential_type = "api_key" + db_record = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=str(uuid4()), + provider="google", + name="API KEY 1", + encrypted_credentials=json.dumps({"api_key": "encrypted"}), + credential_type="api_key", + ) + db_session_with_containers.add(db_record) + db_session_with_containers.commit() - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.first.return_value = db_record - mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - - result = PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id="cred-1", - provider_type="tool", - ) + result = PluginParameterService.get_dynamic_select_options( + tenant_id=tenant_id, + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=db_record.id, + provider_type="tool", + ) assert result == ["opt1"] @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + def test_raises_when_tool_provider_not_found( + self, + mock_tool_mgr, + mock_encrypter_fn, + flask_app_with_containers, + db_session_with_containers, + ): provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl mock_encrypter_fn.return_value = (MagicMock(), None) - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - - with pytest.raises(ValueError, match="not found"): - PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id=None, - provider_type="tool", - ) + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id=str(uuid4()), + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) class TestGetDynamicSelectOptionsTrigger: diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 09b9ab498b..0cdae572fb 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -8,15 +8,27 @@ verification, marketplace upgrade flows, and uninstall with credential cleanup. from __future__ import annotations from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import select from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope from services.plugin.plugin_service import PluginService -from tests.unit_tests.services.plugin.conftest import make_features + + +def _make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features class TestFetchLatestPluginVersion: @@ -80,14 +92,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: @patch("services.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() @patch("services.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) PluginService._check_marketplace_only_permission() # should not raise @@ -95,7 +107,7 @@ class TestCheckMarketplaceOnlyPermission: class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -103,14 +115,14 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) verification = MagicMock() @@ -120,7 +132,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) @@ -129,7 +141,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -138,7 +150,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) PluginService._check_plugin_installation_scope(None) # should not raise @@ -209,9 +221,9 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value - installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + installer.fetch_plugin_manifest.return_value = MagicMock() installer.upgrade_plugin.return_value = MagicMock() PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") @@ -225,7 +237,7 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg-bytes" @@ -244,7 +256,7 @@ class TestUpgradePluginWithGithub: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.upgrade_plugin.return_value = MagicMock() @@ -259,7 +271,7 @@ class TestUploadPkg: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() upload_resp.verification = None mock_installer_cls.return_value.upload_pkg.return_value = upload_resp @@ -283,7 +295,7 @@ class TestInstallFromMarketplacePkg: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg" @@ -298,14 +310,14 @@ class TestInstallFromMarketplacePkg: assert result == "task-id" installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["resolved-uid"] # uses response uid, not input + assert call_args[1] == ["resolved-uid"] @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.return_value = MagicMock() decode_resp = MagicMock() @@ -317,7 +329,7 @@ class TestInstallFromMarketplacePkg: installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["uid-1"] # uses original uid + assert call_args[1] == ["uid-1"] class TestUninstall: @@ -332,26 +344,70 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.db") @patch("services.plugin.plugin_service.PluginInstaller") - def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + def test_cleans_credentials_when_plugin_found( + self, mock_installer_cls, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + plugin_id = "org/myplugin" + provider_name = f"{plugin_id}/model-provider" + + credential = ProviderCredential( + tenant_id=tenant_id, + provider_name=provider_name, + credential_name="default", + encrypted_config="{}", + ) + db_session_with_containers.add(credential) + db_session_with_containers.flush() + credential_id = credential.id + + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + credential_id=credential_id, + ) + db_session_with_containers.add(provider) + db_session_with_containers.flush() + provider_id = provider.id + + pref = TenantPreferredModelProvider( + tenant_id=tenant_id, + provider_name=provider_name, + preferred_provider_type="custom", + ) + db_session_with_containers.add(pref) + db_session_with_containers.commit() + plugin = MagicMock() plugin.installation_id = "install-1" - plugin.plugin_id = "org/myplugin" + plugin.plugin_id = plugin_id installer = mock_installer_cls.return_value installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - # Mock Session context manager - mock_session = MagicMock() - mock_db.engine = MagicMock() - mock_session.scalars.return_value.all.return_value = [] # no credentials found - - with patch("services.plugin.plugin_service.Session") as mock_session_cls: - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - - result = PluginService.uninstall("t1", "install-1") + with patch("services.plugin.plugin_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + result = PluginService.uninstall(tenant_id, "install-1") assert result is True installer.uninstall.assert_called_once() + + db_session_with_containers.expire_all() + + remaining_creds = db_session_with_containers.scalars( + select(ProviderCredential).where(ProviderCredential.id == credential_id) + ).all() + assert len(remaining_creds) == 0 + + updated_provider = db_session_with_containers.get(Provider, provider_id) + assert updated_provider is not None + assert updated_provider.credential_id is None + + remaining_prefs = db_session_with_containers.scalars( + select(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == provider_name, + ) + ).all() + assert len(remaining_prefs) == 0 diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py b/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..2b842629a7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +from models.model import App, RecommendedApp, Site +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +def _create_app(db_session, *, tenant_id: str, is_public: bool = True) -> App: + app = App( + tenant_id=tenant_id, + name=f"app-{uuid4()}", + mode="chat", + enable_site=True, + enable_api=True, + is_public=is_public, + ) + app.id = str(uuid4()) + db_session.add(app) + db_session.commit() + return app + + +def _create_site(db_session, *, app_id: str) -> Site: + site = Site( + app_id=app_id, + title=f"site-{uuid4()}", + default_language="en-US", + customize_token_strategy="not_allow", + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + site.id = str(uuid4()) + db_session.add(site) + db_session.commit() + return site + + +def _create_recommended_app( + db_session, + *, + app_id: str, + category: str = "chat", + language: str = "en-US", + is_listed: bool = True, + position: int = 1, +) -> RecommendedApp: + rec = RecommendedApp( + app_id=app_id, + description={"en-US": "test"}, + copyright="copy", + privacy_policy="pp", + category=category, + language=language, + is_listed=is_listed, + position=position, + ) + rec.id = str(uuid4()) + db_session.add(rec) + db_session.commit() + return rec + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, category="writing") + + app2 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app2.id) + _create_recommended_app(db_session_with_containers, app_id=app2.id, category="assistant") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + assert app2.id in app_ids + assert "assistant" in result["categories"] + assert "writing" in result["categories"] + + def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, language="en-US") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + +class TestFetchRecommendedAppDetailFromDb: + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) + + assert result is None + + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + mock_dsl.export_dsl.return_value = "exported_yaml" + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is not None + assert result["id"] == app1.id + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index b51fbc3a42..00a2f9a59f 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service @@ -841,7 +841,7 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from dify_graph.file import FileTransferMethod, FileType + from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 7ce7357b41..b8e022503f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -525,3 +525,147 @@ class TestAPIBasedExtensionService: # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + + def test_save_extension_api_key_exactly_four_chars_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 4 characters should be rejected (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="1234", + ) + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_api_key_exactly_five_chars_accepted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 5 characters should be accepted (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="12345", + ) + + saved = APIBasedExtensionService.save(extension_data) + assert saved.id is not None + + def test_save_extension_requestor_constructor_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Exception raised by requestor constructor is wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config") + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_network_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Network exceptions during ping are wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError( + "network failure" + ) + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_update_duplicate_name_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Updating an existing extension to use another extension's name should fail.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + ext1 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Alpha", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + ext2 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Beta", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + # Try to rename ext2 to ext1's name + ext2.name = "Extension Alpha" + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(ext2) + + def test_get_all_returns_empty_for_different_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Extensions from one tenant should not be visible to another.""" + fake = Faker() + _, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + _, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant1 is not None + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant1.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + assert tenant2 is not None + result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + assert result == [] diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/test_containers_integration_tests/services/test_api_token_service.py similarity index 71% rename from api/tests/unit_tests/services/test_api_token_service.py rename to api/tests/test_containers_integration_tests/services/test_api_token_service.py index ad4de93b25..a2028d3ed3 100644 --- a/api/tests/unit_tests/services/test_api_token_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_token_service.py @@ -1,80 +1,63 @@ +from __future__ import annotations + from datetime import datetime -from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest from werkzeug.exceptions import Unauthorized import services.api_token_service as api_token_service_module +from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken -@pytest.fixture -def mock_db_session(): - """Fixture providing common DB session mocking for query_token_from_db tests.""" - fake_engine = MagicMock() - - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - with ( - patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), - patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, - patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, - patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, - ): - yield { - "session": session, - "mock_session_class": mock_session_class, - "mock_cache_set": mock_cache_set, - "mock_record_usage": mock_record_usage, - "fake_engine": fake_engine, - } - - class TestQueryTokenFromDb: - def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): - """Test DB lookup success path caches token and records usage.""" - # Arrange - auth_token = "token-123" - scope = "app" - api_token = MagicMock() + def test_should_return_api_token_and_cache_when_token_exists( + self, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + app_id = str(uuid4()) + token_value = f"app-test-{uuid4()}" - mock_db_session["session"].scalar.return_value = api_token + api_token = ApiToken() + api_token.id = str(uuid4()) + api_token.app_id = app_id + api_token.tenant_id = tenant_id + api_token.type = "app" + api_token.token = token_value + db_session_with_containers.add(api_token) + db_session_with_containers.commit() - # Act - result = api_token_service_module.query_token_from_db(auth_token, scope) + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + result = api_token_service_module.query_token_from_db(token_value, "app") - # Assert - assert result == api_token - mock_db_session["mock_session_class"].assert_called_once_with( - mock_db_session["fake_engine"], expire_on_commit=False - ) - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) - mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + assert result.id == api_token.id + assert result.token == token_value + mock_cache_set.assert_called_once() + mock_record_usage.assert_called_once_with(token_value, "app") - def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): - """Test DB lookup miss path caches null marker and raises Unauthorized.""" - # Arrange - auth_token = "missing-token" - scope = "app" + def test_should_cache_null_and_raise_unauthorized_when_token_not_found( + self, flask_app_with_containers, db_session_with_containers + ): + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(f"missing-{uuid4()}", "app") - mock_db_session["session"].scalar.return_value = None - - # Act / Assert - with pytest.raises(Unauthorized, match="Access token is invalid"): - api_token_service_module.query_token_from_db(auth_token, scope) - - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) - mock_db_session["mock_record_usage"].assert_not_called() + mock_cache_set.assert_called_once() + call_args = mock_cache_set.call_args[0] + assert call_args[2] is None # cached None + mock_record_usage.assert_not_called() class TestRecordTokenUsage: def test_should_write_active_key_with_iso_timestamp_and_ttl(self): - """Test record_token_usage writes usage timestamp with one-hour TTL.""" - # Arrange auth_token = "token-123" scope = "dataset" fixed_time = datetime(2026, 2, 24, 12, 0, 0) @@ -84,26 +67,18 @@ class TestRecordTokenUsage: patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), patch.object(api_token_service_module, "redis_client") as mock_redis, ): - # Act api_token_service_module.record_token_usage(auth_token, scope) - # Assert mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) def test_should_not_raise_when_redis_write_fails(self): - """Test record_token_usage swallows Redis errors.""" - # Arrange with patch.object(api_token_service_module, "redis_client") as mock_redis: mock_redis.set.side_effect = Exception("redis unavailable") - - # Act / Assert api_token_service_module.record_token_usage("token-123", "app") class TestFetchTokenWithSingleFlight: def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): - """Test single-flight returns cache when another request already populated it.""" - # Arrange auth_token = "token-123" scope = "app" cached_token = CachedApiToken( @@ -115,39 +90,26 @@ class TestFetchTokenWithSingleFlight: last_used_at=None, created_at=None, ) - lock = MagicMock() lock.acquire.return_value = True with ( patch.object(api_token_service_module, "redis_client") as mock_redis, - patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token), patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == cached_token - mock_redis.lock.assert_called_once_with( - f"api_token_query_lock:{scope}:{auth_token}", - timeout=10, - blocking_timeout=5, - ) lock.acquire.assert_called_once_with(blocking=True) lock.release.assert_called_once() - mock_cache_get.assert_called_once_with(auth_token, scope) mock_query_db.assert_not_called() def test_should_query_db_when_lock_acquired_and_cache_missed(self): - """Test single-flight queries DB when cache remains empty after lock acquisition.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = True @@ -157,22 +119,16 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_called_once() def test_should_query_db_directly_when_lock_not_acquired(self): - """Test lock timeout branch falls back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = False @@ -182,19 +138,14 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_cache_get.assert_not_called() mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_not_called() def test_should_reraise_unauthorized_from_db_query(self): - """Test Unauthorized from DB query is propagated unchanged.""" - # Arrange auth_token = "token-123" scope = "app" lock = MagicMock() @@ -210,20 +161,15 @@ class TestFetchTokenWithSingleFlight: ), ): mock_redis.lock.return_value = lock - - # Act / Assert with pytest.raises(Unauthorized, match="Access token is invalid"): api_token_service_module.fetch_token_with_single_flight(auth_token, scope) lock.release.assert_called_once() def test_should_fallback_to_db_query_when_lock_raises_exception(self): - """Test Redis lock errors fall back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.side_effect = RuntimeError("redis lock error") @@ -232,11 +178,8 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) @@ -244,8 +187,6 @@ class TestFetchTokenWithSingleFlight: class TestApiTokenCacheTenantBranches: @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): - """Test scoped delete removes cache key and tenant index membership.""" - # Arrange token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) @@ -261,18 +202,14 @@ class TestApiTokenCacheTenantBranches: mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: - # Act result = ApiTokenCache.delete(token, scope) - # Assert assert result is True mock_redis.delete.assert_called_once_with(cache_key) mock_remove_index.assert_called_once_with("tenant-1", cache_key) @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): - """Test tenant invalidation deletes indexed cache entries and index key.""" - # Arrange tenant_id = "tenant-1" index_key = ApiTokenCache._make_tenant_index_key(tenant_id) mock_redis.smembers.return_value = { @@ -280,10 +217,8 @@ class TestApiTokenCacheTenantBranches: b"api_token:any:token-2", } - # Act result = ApiTokenCache.invalidate_by_tenant(tenant_id) - # Assert assert result is True mock_redis.smembers.assert_called_once_with(index_key) mock_redis.delete.assert_any_call("api_token:app:token-1") @@ -293,7 +228,6 @@ class TestApiTokenCacheTenantBranches: class TestApiTokenCacheCoreBranches: def test_cached_api_token_repr_should_include_id_and_type(self): - """Test CachedApiToken __repr__ includes key identity fields.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -303,11 +237,9 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - assert repr(token) == "" def test_serialize_token_should_handle_cached_api_token_instances(self): - """Test serialization path when input is already a CachedApiToken.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -317,35 +249,25 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - serialized = ApiTokenCache._serialize_token(token) - assert isinstance(serialized, bytes) assert b'"id":"id-123"' in serialized - assert b'"token":"token-123"' in serialized def test_deserialize_token_should_return_none_for_null_markers(self): - """Test null cache marker deserializes to None.""" assert ApiTokenCache._deserialize_token("null") is None assert ApiTokenCache._deserialize_token(b"null") is None def test_deserialize_token_should_return_none_for_invalid_payload(self): - """Test invalid serialized payload returns None.""" assert ApiTokenCache._deserialize_token("not-json") is None @patch("services.api_token_service.redis_client") def test_get_should_return_none_on_cache_miss(self, mock_redis): - """Test cache miss branch in ApiTokenCache.get.""" mock_redis.get.return_value = None - result = ApiTokenCache.get("token-123", "app") - assert result is None - mock_redis.get.assert_called_once_with("api_token:app:token-123") @patch("services.api_token_service.redis_client") def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): - """Test cache hit branch in ApiTokenCache.get.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -356,48 +278,34 @@ class TestApiTokenCacheCoreBranches: created_at=None, ) mock_redis.get.return_value = token.model_dump_json().encode("utf-8") - result = ApiTokenCache.get("token-123", "app") - assert isinstance(result, CachedApiToken) assert result.id == "id-123" @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index update exits early for missing tenant id.""" ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") - mock_redis.sadd.assert_not_called() - mock_redis.expire.assert_not_called() @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): - """Test tenant index update handles Redis write errors gracefully.""" mock_redis.sadd.side_effect = Exception("redis down") - ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.sadd.assert_called_once() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index removal exits early for missing tenant id.""" ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") - mock_redis.srem.assert_not_called() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): - """Test tenant index removal handles Redis errors gracefully.""" mock_redis.srem.side_effect = Exception("redis down") - ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.srem.assert_called_once() @patch("services.api_token_service.redis_client") def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): - """Test set returns False when Redis setex fails.""" mock_redis.setex.side_effect = Exception("redis write failed") api_token = MagicMock() api_token.id = "id-123" @@ -407,60 +315,41 @@ class TestApiTokenCacheCoreBranches: api_token.token = "token-123" api_token.last_used_at = None api_token.created_at = None - result = ApiTokenCache.set("token-123", "app", api_token) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): - """Test delete(scope=None) returns False when scan_iter raises.""" mock_redis.scan_iter.side_effect = Exception("scan failed") - result = ApiTokenCache.delete("token-123", None) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): - """Test scoped delete still succeeds when tenant lookup from cache fails.""" token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) mock_redis.get.side_effect = Exception("get failed") - result = ApiTokenCache.delete(token, scope) - assert result is True mock_redis.delete.assert_called_once_with(cache_key) @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): - """Test scoped delete returns False when delete operation fails.""" - token = "token-123" - scope = "app" mock_redis.get.return_value = None mock_redis.delete.side_effect = Exception("delete failed") - - result = ApiTokenCache.delete(token, scope) - + result = ApiTokenCache.delete("token-123", "app") assert result is False @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): - """Test tenant invalidation returns True when tenant index is empty.""" mock_redis.smembers.return_value = set() - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is True mock_redis.delete.assert_not_called() @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): - """Test tenant invalidation returns False when Redis operation fails.""" mock_redis.smembers.side_effect = Exception("redis failed") - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e..33955d5d84 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..fa57dd4a6f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. @@ -1142,3 +1245,51 @@ class TestAppService: assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) + + def test_get_app_code_by_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_code_by_id raises ValueError when site is missing.""" + from uuid import uuid4 + + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id(str(uuid4())) + + def test_get_app_id_by_code_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_id_by_code raises ValueError when code does not exist.""" + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("nonexistent-code") + + def test_get_app_meta_returns_empty_when_workflow_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when workflow is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + workflow_app = SimpleNamespace(mode="workflow", workflow=None) + + meta = app_service.get_app_meta(workflow_app) + assert meta == {"tool_icons": {}} + + def test_get_app_meta_returns_empty_when_model_config_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when app_model_config is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + chat_app = SimpleNamespace(mode="chat", app_model_config=None) + + meta = app_service.get_app_meta(chat_app) + assert meta == {"tool_icons": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py new file mode 100644 index 0000000000..768a8baee2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -0,0 +1,80 @@ +"""Testcontainers integration tests for AttachmentService.""" + +import base64 +from datetime import UTC, datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from extensions.ext_database import db +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id or str(uuid4()), + storage_type=StorageType.OPENDAL, + key=f"upload/{uuid4()}.txt", + name="test-file.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + def test_should_initialize_with_sessionmaker(self): + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_engine(self): + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_when_file_exists(self, db_session_with_containers): + upload_file = self._create_upload_file(db_session_with_containers) + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64(upload_file.id) + + assert result == base64.b64encode(b"binary-content").decode() + mock_load.assert_called_once_with(upload_file.key) + + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64(str(uuid4())) + + mock_load.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py new file mode 100644 index 0000000000..02ab3f8314 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,58 @@ +"""Testcontainers integration tests for ConversationVariableUpdater.""" + +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from extensions.ext_database import db +from graphon.variables import StringVariable +from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def _create_conversation_variable( + self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + ) -> ConversationVariable: + row = ConversationVariable( + id=variable.id, + conversation_id=conversation_id, + app_id=app_id or str(uuid4()), + data=variable.model_dump_json(), + ) + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="old value") + self._create_conversation_variable( + db_session_with_containers, conversation_id=conversation_id, variable=variable + ) + + updated_variable = StringVariable(id=variable.id, name="topic", value="new value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + updater.update(conversation_id=conversation_id, variable=updated_variable) + + db_session_with_containers.expire_all() + row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id)) + assert row is not None + assert row.data == updated_variable.model_dump_json() + + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + result = updater.flush() + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..0f63d98642 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -0,0 +1,104 @@ +"""Testcontainers integration tests for CreditPoolService.""" + +from uuid import uuid4 + +import pytest + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from models.enums import ProviderQuotaType +from services.credit_pool_service import CreditPoolService + + +class TestCreditPoolService: + def _create_tenant_id(self) -> str: + return str(uuid4()) + + def test_create_default_pool(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + + pool = CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == ProviderQuotaType.TRIAL + assert pool.quota_used == 0 + assert pool.quota_limit > 0 + + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) + + assert result is not None + assert result.tenant_id == tenant_id + assert result.pool_type == ProviderQuotaType.TRIAL + + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) + + assert result is None + + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) + + assert result is False + + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10) + + assert result is True + + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + # Exhaust credits + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1) + + assert result is False + + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) + + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) + + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + credits_required = 10 + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required) + + assert result == credits_required + db_session_with_containers.expire_all() + pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert pool.quota_used == credits_required + + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + remaining = 5 + pool.quota_used = pool.quota_limit - remaining + db_session_with_containers.commit() + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + + assert result == remaining + db_session_with_containers.expire_all() + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool.quota_used == pool.quota_limit diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 975af3d428..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", @@ -397,6 +398,68 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + """Test that users from different tenants cannot access dataset.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other_user) + + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + """Test that tenant owners can access any dataset regardless of permission level.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, owner) + + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + """Test ONLY_ME permission allows only the dataset creator to access.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + """Test ONLY_ME permission denies access to non-creators.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other) + + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + """Test ALL_TEAM permission allows any team member to access the dataset.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, member) + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): """ Test that user with explicit permission can access partial_members dataset. @@ -443,6 +506,16 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, creator) + class TestDatasetServiceCheckDatasetOperatorPermission: """Verify operator permission checks against persisted partial-member permissions.""" diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..0de3c64c4f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,8 +11,9 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -173,20 +174,20 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -263,7 +264,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -296,7 +297,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model @@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, @@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] + + +class TestDocumentServicePauseRecoverRetry: + """Tests for pause/recover/retry orchestration using real DB and Redis.""" + + def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc = factory.create_document(db_session_with_containers, dataset, account.id) + doc.indexing_status = indexing_status + db_session_with_containers.commit() + return doc, account + + def test_pause_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is True + assert doc.paused_by == account.id + assert doc.paused_at is not None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is not None + redis_client.delete(cache_key) + + def test_pause_document_invalid_status_error(self, db_session_with_containers): + from services.dataset_service import DocumentService + from services.errors.document import DocumentIndexingError + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(doc) + + def test_recover_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + # Pause first + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + # Recover + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is False + assert doc.paused_by is None + assert doc.paused_at is None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is None + recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) + + def test_retry_document_indexing_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt") + doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt") + doc2.position = 2 + doc1.indexing_status = "error" + doc2.indexing_status = "error" + db_session_with_containers.commit() + + with ( + patch("services.dataset_service.current_user") as mock_user, + patch("services.dataset_service.retry_document_indexing_task") as retry_task, + ): + mock_user.id = account.id + DocumentService.retry_document(dataset.id, [doc1, doc2]) + + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + assert doc1.indexing_status == "waiting" + assert doc2.indexing_status == "waiting" + + # Verify redis keys were set + assert redis_client.get(f"document_{doc1.id}_is_retried") is not None + assert redis_client.get(f"document_{doc2.id}_is_retried") is not None + retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id) + + # Cleanup + redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried") diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 7983b1cd93..c1d088755c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled @@ -694,3 +695,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) + + def test_batch_update_invalid_action_raises_value_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Test that an invalid action raises ValueError.""" + factory = DocumentBatchUpdateIntegrationDataFactory + dataset = factory.create_dataset(db_session_with_containers) + doc = factory.create_document(db_session_with_containers, dataset) + user = UserDouble(id=str(uuid4())) + + patched_dependencies["redis_client"].get.return_value = None + + with pytest.raises(ValueError, match="Invalid action"): + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py new file mode 100644 index 0000000000..c486ff5613 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -0,0 +1,60 @@ +"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.account import Account, Tenant, TenantAccountJoin +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity + + +class TestDatasetServiceCreateRagPipelineDataset: + def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"ds_create_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return tenant, account + + def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity: + icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji") + return RagPipelineDatasetCreateEntity( + name=name, + description="", + icon_info=icon_info, + permission="only_me", + ) + + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + tenant, _ = self._create_tenant_and_account(db_session_with_containers) + + mock_user = Mock(id=None) + with patch("services.dataset_service.current_user", mock_user): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=self._build_entity(), + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c9..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act @@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..883c3c3feb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,7 +4,8 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from models.enums import DataSourceType @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -362,7 +363,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -457,7 +458,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", @@ -543,7 +544,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 5f86cb2ae9..fe426ae516 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -7,7 +7,7 @@ from uuid import uuid4 from sqlalchemy import select -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion @@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None + + def test_delete_run_dry_run(self, db_session_with_containers): + """Dry run should return success without actually deleting.""" + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + run_id = run.id + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + + result = deleter._delete_run(run) + + assert result.success is True + assert result.run_id == run_id + # Run should still exist because it's a dry run + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is not None + + def test_delete_run_exception_returns_error(self, db_session_with_containers): + """Exception during deletion should return failure result.""" + from unittest.mock import MagicMock, patch + + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + deleter = ArchivedWorkflowRunDeletion(dry_run=False) + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deleter._delete_run(run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_by_run_id_success(self, db_session_with_containers): + """Successfully delete an archived workflow run by ID.""" + tenant_id = str(uuid4()) + base_time = datetime.now(UTC) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=base_time, + ) + self._create_archive_log(db_session_with_containers, run=run) + run_id = run.id + + deleter = ArchivedWorkflowRunDeletion() + result = deleter.delete_by_run_id(run_id) + + assert result.success is True + db_session_with_containers.expunge_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is None + + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + """_get_workflow_run_repo should return a cached repo on subsequent calls.""" + deleter = ArchivedWorkflowRunDeletion() + + repo1 = deleter._get_workflow_run_repo() + repo2 = deleter._get_workflow_run_repo() + + assert repo1 is repo2 + assert deleter.workflow_run_repo is repo1 diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c733..c0047df810 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status @@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index bffa520ce6..34532ed7f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -69,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index ae811db768..cafabc939b 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById: ) assert result is None + + +class TestEndUserServiceCreateBatch: + """Integration tests for EndUserService.create_end_user_batch.""" + + @pytest.fixture + def factory(self): + return TestEndUserServiceFactory() + + def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + """Create multiple apps under the same tenant.""" + first_app = factory.create_app_and_account(db_session_with_containers) + tenant_id = first_app.tenant_id + apps = [first_app] + for _ in range(count - 1): + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=first_app.created_by, + updated_by=first_app.updated_by, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() + return tenant_id, all_apps + + def test_create_batch_empty_app_ids(self, db_session_with_containers): + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" + ) + assert result == {} + + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 3 + for app_id in app_ids: + assert app_id in result + assert result[app_id].session_id == user_id + assert result[app_id].type == InvokeFrom.SERVICE_API + + def test_create_batch_default_session_id(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="" + ) + + assert len(result) == 2 + for end_user in result.values(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 2 + + def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + # Create batch first time + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Create batch second time — should return existing users + second_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(second_result) == 2 + for app_id in app_ids: + assert first_result[app_id].id == second_result[app_id].id + + def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + user_id = f"user-{uuid4()}" + + # Create for first 2 apps + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[apps[0].id, apps[1].id], + user_id=user_id, + ) + + # Create for all 3 apps — should reuse first 2, create 3rd + all_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[a.id for a in apps], + user_id=user_id, + ) + + assert len(all_result) == 3 + assert all_result[apps[0].id].id == first_result[apps[0].id].id + assert all_result[apps[1].id].id == first_result[apps[1].id].id + assert all_result[apps[2].id].session_id == user_id + + @pytest.mark.parametrize( + "invoke_type", + [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], + ) + def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id + ) + + assert len(result) == 1 + assert result[apps[0].id].type == invoke_type diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 0000000000..4e0a726cc7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,96 @@ +""" +Testcontainers integration tests for FileService helpers. + +Covers: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (get_upload_files_by_ids) +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.file_service import FileService + + +def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.OPENDAL, + key=key, + name=name, + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session.add(upload_file) + db_session.commit() + return upload_file + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + tenant_id = str(uuid4()) + file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") + file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") + + result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + + assert set(result.keys()) == {file1.id, file2.id} + assert result[file1.id].id == file1.id + assert result[file2.id].id == file2.id + + +def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None: + """Ensure files from other tenants are not returned.""" + tenant_a = str(uuid4()) + tenant_b = str(uuid4()) + file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") + _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") + + result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + + assert set(result.keys()) == {file_a.id} diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792ce..18c5320d0a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py similarity index 79% rename from api/tests/unit_tests/services/test_human_input_delivery_test_service.py rename to api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index a23c44b26e..21a54e909e 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -1,18 +1,22 @@ +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest from sqlalchemy.engine import Engine from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool +from graphon.runtime import VariablePool +from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, @@ -28,15 +32,12 @@ from services.human_input_delivery_test_service import ( ) -@pytest.fixture -def mock_db(monkeypatch): - mock_db = MagicMock() - monkeypatch.setattr(service_module, "db", mock_db) - return mock_db - - def _make_valid_email_config(): - return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + return EmailDeliveryConfig( + recipients=EmailRecipients(include_bound_group=False, items=[]), + subject="Subj", + body="Body", + ) def test_build_form_link(): @@ -87,7 +88,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, mock_db): + def test_default(self, flask_app_with_containers, db_session_with_containers): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -246,62 +247,70 @@ class TestEmailDeliveryTestHandler: _, kwargs = mock_mail_send.call_args assert kwargs["subject"] == "Notice BCC:test@example.com" - def test_resolve_recipients(self): + def test_resolve_recipients_external(self): handler = EmailDeliveryTestHandler(session_factory=MagicMock()) - - # Test Case 1: External Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + recipients=EmailRecipients( + items=[ExternalRecipient(email="ext@example.com")], include_bound_group=False + ), subject="", body="", ) ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - # Test Case 2: Member Recipient + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account = Account(name="Test User", email="member@example.com") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + recipients=EmailRecipients(items=[MemberRecipient(reference_id=account.id)], include_bound_group=False), subject="", body="", ) ) - handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) - assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - # Test Case 3: Whole Workspace + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") + account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") + db_session_with_containers.add_all([account1, account2]) + db_session_with_containers.commit() + + for acc in [account1, account2]: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=acc.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( - config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[], include_bound_group=True), + subject="", + body="", + ) ) - handler._query_workspace_member_emails = MagicMock( - return_value={"u1": "u1@example.com", "u2": "u2@example.com"} - ) - recipients = handler._resolve_recipients(tenant_id="t1", method=method) - assert set(recipients) == {"u1@example.com", "u2@example.com"} + recipients = handler._resolve_recipients(tenant_id=tenant_id, method=method) + assert set(recipients) == {account1.email, account2.email} - def test_query_workspace_member_emails(self): - mock_session = MagicMock() - mock_session_factory = MagicMock(return_value=mock_session) - mock_session.__enter__.return_value = mock_session - - handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) - - # Empty user_ids + def test_query_workspace_member_emails_empty_ids(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} - # user_ids is None (all) - mock_execute = MagicMock() - mock_session.execute.return_value = mock_execute - mock_execute.all.return_value = [("u1", "u1@example.com")] - - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) - assert result == {"u1": "u1@example.com"} - - # user_ids with values - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) - assert result == {"u1": "u1@example.com"} - def test_build_substitutions(self): context = DeliveryTestContext( tenant_id="t1", @@ -313,7 +322,8 @@ class TestEmailDeliveryTestHandler: recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], ) - subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") assert subs["node_title"] == "title" assert subs["form_content"] == "content" @@ -322,7 +332,6 @@ class TestEmailDeliveryTestHandler: assert subs["form_token"] == "token123" assert "form/token123" in subs["form_link"] - # Without matching recipient subs_no_match = EmailDeliveryTestHandler._build_substitutions( context=context, recipient_email="other@example.com" ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 85dc04b162..bdf6d9b951 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 8707f2e827..c0c1c25f1e 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from graphon.file.enums import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, @@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py new file mode 100644 index 0000000000..b55a19eaa9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +from models.dataset import Dataset, DatasetMetadataBinding, Document +from models.enums import DataSourceType, DocumentCreatedFrom +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def _create_dataset(db_session, *, tenant_id: str, built_in_field_enabled: bool = False) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + ) + dataset.id = str(uuid4()) + dataset.built_in_field_enabled = built_in_field_enabled + db_session.add(dataset) + db_session.commit() + return dataset + + +def _create_document(db_session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None) -> Document: + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info="{}", + batch=f"batch-{uuid4()}", + name=f"doc-{uuid4()}", + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + ) + document.id = str(uuid4()) + document.doc_metadata = doc_metadata + db_session.add(document) + db_session.commit() + return document + + +class TestMetadataPartialUpdate: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def user_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def mock_current_account(self, user_id, tenant_id): + account = Mock(id=user_id, current_tenant_id=tenant_id) + with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)): + yield account + + def test_partial_update_merges_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata["existing_key"] == "existing_value" + assert updated_doc.doc_metadata["new_key"] == "new_value" + + def test_full_update_replaces_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata == {"new_key": "new_value"} + assert "existing_key" not in updated_doc.doc_metadata + + def test_partial_update_skips_existing_binding( + self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + existing_binding = DatasetMetadataBinding( + tenant_id=tenant_id, + dataset_id=dataset.id, + document_id=document.id, + metadata_id=meta_id, + created_by=user_id, + ) + db_session_with_containers.add(existing_binding) + db_session_with_containers.commit() + + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="existing_key", value="existing_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where( + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == meta_id, + ) + ).all() + assert len(bindings) == 1 + + def test_rollback_called_on_commit_failure( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="key", value="value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with pytest.raises(RuntimeError, match="database connection lost"): + MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b..8b1349be9a 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 989df42499..ca6e7afeab 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -18,11 +18,10 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch( - "services.model_load_balancing_service.ModelProviderFactory", autospec=True - ) as mock_model_provider_factory, + "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -46,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -61,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 6afc5aa43c..8955a3b5f2 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -5,7 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -18,8 +18,12 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, + patch( + "services.model_provider_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch( + "services.model_provider_service.create_plugin_model_provider_factory", autospec=True + ) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() @@ -402,8 +406,8 @@ class TestModelProviderService: # Create mock models from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject - from dify_graph.model_runtime.entities.provider_entities import ProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -640,7 +644,7 @@ class TestModelProviderService: # Create mock default model response from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index ba4310e22e..7036524918 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -2,17 +2,43 @@ Testcontainers integration tests for workflow run restore functionality. """ +from __future__ import annotations + +from datetime import datetime from uuid import uuid4 from sqlalchemy import select -from models.workflow import WorkflowPause +from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore class TestWorkflowRunRestore: """Tests for the WorkflowRunRestore class.""" + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + def test_restore_table_records_returns_rowcount(self, db_session_with_containers): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 94a4e62560..70aa813142 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -20,7 +20,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns @@ -396,11 +396,6 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) - # Verify no database operations were performed - - saved_messages = db_session_with_containers.query(SavedMessage).all() - assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -497,124 +492,140 @@ class TestSavedMessageService: # The message should still exist, only the saved_message should be deleted assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): - """ - Test error handling when no user is provided. - - This test verifies: - - Proper error handling for missing user - - ValueError is raised when user is None - - No database operations are performed - """ - # Arrange: Create test data - fake = Faker() + def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test saving a message for an EndUser.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) - # Act & Assert: Verify proper error handling - with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + mock_external_service_dependencies["message_service"].get_message.return_value = message - assert "User is required" in str(exc_info.value) + SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) - # Verify no database operations were performed for this specific test - # Note: We don't check total count as other tests may have created data - # Instead, we verify that the error was properly raised - pass - - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): - """ - Test error handling when saving message with no user. - - This test verifies: - - Method returns early when user is None - - No database operations are performed - - No exceptions are raised - """ - # Arrange: Create test data - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - message = self._create_test_message(db_session_with_containers, app, account) - - # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) - - # Assert: Verify the expected outcomes - assert result is None - - # Verify no saved message was created - - saved_message = ( + saved = ( db_session_with_containers.query(SavedMessage) - .where( - SavedMessage.app_id == app.id, - SavedMessage.message_id == message.id, - ) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) .first() ) + assert saved is not None + assert saved.created_by == end_user.id + assert saved.created_by_role == "end_user" - assert saved_message is None - - def test_delete_success_existing_message( + def test_save_duplicate_is_idempotent( self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test successful deletion of an existing saved message. - - This test verifies: - - Proper deletion of existing saved message - - Correct database state after deletion - - No errors during deletion process - """ - # Arrange: Create test data - fake = Faker() + """Test that saving an already-saved message does not create a duplicate.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) message = self._create_test_message(db_session_with_containers, app, account) - # Create a saved message first - saved_message = SavedMessage( - app_id=app.id, - message_id=message.id, - created_by_role="account", - created_by=account.id, - ) + mock_external_service_dependencies["message_service"].get_message.return_value = message - db_session_with_containers.add(saved_message) + # Save once + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + # Save again + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + count = ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .count() + ) + assert count == 1 + + def test_delete_without_user_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting without a user is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Pre-create a saved message + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + db_session_with_containers.add(saved) db_session_with_containers.commit() - # Verify saved message exists + SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + + # Should still exist + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is not None + ) + + def test_delete_non_existent_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting a non-existent saved message is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Should not raise — use a valid UUID that doesn't exist in DB + from uuid import uuid4 + + SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + + def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test deleting a saved message for an EndUser.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) + + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + db_session_with_containers.add(saved) + db_session_with_containers.commit() + + SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is None + ) + + def test_delete_only_affects_own_saved_messages( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that delete only removes the requesting user's saved message.""" + app, account1 = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, account1) + + # Both users save the same message + saved_account = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + ) + saved_end_user = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + ) + db_session_with_containers.add_all([saved_account, saved_end_user]) + db_session_with_containers.commit() + + # Delete only account1's saved message + SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + + # Account's saved message should be gone assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == account1.id, ) .first() - is not None + is None ) - - # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) - - # Assert: Verify the expected outcomes - # Check if saved message was deleted from database - deleted_saved_message = ( + # End user's saved message should still exist + assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == end_user.id, ) .first() + is not None ) - - assert deleted_saved_message is None - - # Verify database state - db_session_with_containers.commit() - # The message should still exist, only the saved_message should be deleted - assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index fa6e651529..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,9 +7,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset -from models.enums import DataSourceType +from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) @@ -547,7 +548,7 @@ class TestTagService: assert result is not None assert len(result) == 1 assert result[0].name == "python_tag" - assert result[0].type == "app" + assert result[0].type == TagType.APP assert result[0].tenant_id == tenant.id def test_get_tag_by_tag_name_no_matches( @@ -638,7 +639,7 @@ class TestTagService: # Verify all tags are returned for tag in result: - assert tag.type == "app" + assert tag.type == TagType.APP assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6b95954480..f2307fbd7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -25,7 +25,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..2a18345c87 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1,20 +1,24 @@ +from __future__ import annotations + import json import uuid from datetime import UTC, datetime, timedelta +from types import SimpleNamespace from unittest.mock import patch import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.entities.workflow_execution import WorkflowExecutionStatus -from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun -from models.enums import CreatorUserRole +from graphon.entities.workflow_execution import WorkflowExecutionStatus +from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService -from services.workflow_app_service import WorkflowAppService +from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -27,7 +31,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -221,7 +225,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +361,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +403,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +445,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +525,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +631,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +736,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +864,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +906,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1041,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1129,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1283,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1383,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1485,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1524,3 +1528,168 @@ class TestWorkflowAppService: # Should not find tenant2's data when searching from tenant1's context assert result_cross_tenant["total"] == 0 + + def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"): + service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + ) + + def test_get_paginate_workflow_app_logs_filters_by_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account) + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + ) + + assert result["total"] >= 0 + assert isinstance(result["data"], list) + + def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type="browser", + is_anonymous=False, + session_id="session-1", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + + now = datetime.now(UTC) + archive_defaults = { + "workflow_id": str(uuid.uuid4()), + "run_version": "1.0.0", + "run_status": WorkflowExecutionStatus.SUCCEEDED, + "run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, + "run_error": None, + "run_elapsed_time": 1.0, + "run_total_tokens": 0, + "run_total_steps": 0, + "run_created_at": now, + "run_finished_at": now, + "run_exceptions_count": 0, + "trigger_metadata": '{"type":"trigger-webhook"}', + "log_created_at": now, + "log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API, + } + archive_account = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=account.id, + created_by_role=CreatorUserRole.ACCOUNT, + **archive_defaults, + ) + archive_end_user = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=end_user.id, + created_by_role=CreatorUserRole.END_USER, + **archive_defaults, + ) + db_session_with_containers.add_all([archive_account, archive_end_user]) + db_session_with_containers.commit() + + result = service.get_paginate_workflow_archive_logs( + session=db_session_with_containers, + app_model=app, + page=1, + limit=20, + ) + + assert result["total"] == 2 + assert len(result["data"]) == 2 + account_item = next(d for d in result["data"] if d["created_by_account"] is not None) + end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None) + assert account_item["created_by_account"].id == account.id + assert end_user_item["created_by_end_user"].id == end_user.id + + +class TestLogView: + def test_details_and_proxy_attributes(self): + log = SimpleNamespace(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + assert view.details == {"trigger_metadata": {"type": "plugin"}} + assert view.status == "succeeded" + + +class TestHandleTriggerMetadata: + def test_returns_empty_dict_when_metadata_missing(self): + service = WorkflowAppService() + assert service.handle_trigger_metadata("tenant-1", None) == {} + + def test_enriches_plugin_icons(self): + service = WorkflowAppService() + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + with patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + def test_non_plugin_metadata_without_icon_lookup(self): + service = WorkflowAppService() + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +class TestSafeJsonLoads: + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], + ) + def test_handles_various_inputs(self, value, expected): + assert WorkflowAppService._safe_json_loads(value) == expected + + +class TestSafeParseUuid: + def test_returns_none_for_short_or_invalid_values(self): + service = WorkflowAppService() + assert service._safe_parse_uuid("short") is None + assert service._safe_parse_uuid("x" * 40) is None + + def test_returns_uuid_for_valid_string(self): + service = WorkflowAppService() + raw = str(uuid.uuid4()) + result = service._safe_parse_uuid(raw) + assert result is not None + assert str(result) == raw diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa0..86cf2327c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -2,8 +2,8 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.segments import StringSegment +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -482,7 +482,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -734,7 +734,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 731770e01a..d02a078281 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index a5fe052206..ee7b68e6aa 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -1503,10 +1503,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunSucceededEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1548,12 +1548,12 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None @@ -1578,10 +1578,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1623,7 +1623,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None @@ -1647,10 +1647,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) @@ -1693,7 +1693,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 0f38218c51..2dc50cc720 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -1,12 +1,24 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest from faker import Faker from sqlalchemy.orm import Session -from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolDescription, + ToolEntity, + ToolIdentity, + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -52,7 +64,7 @@ class TestToolTransformService: user_id="test_user_id", credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) elif provider_type == "builtin": @@ -659,7 +671,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -695,7 +707,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -731,7 +743,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -786,3 +798,192 @@ class TestToolTransformService: assert result is not None assert result == mock_controller mock_from_db.assert_called_once_with(provider) + + +def _mock_tool(*, base_params, runtime_params): + """Helper to build a Mock tool with real entity objects. + + Tool is abstract and requires runtime behaviour (fork_tool_runtime, + get_runtime_parameters), so it stays as a Mock. Everything else uses + real Pydantic instances. + """ + entity = ToolEntity( + identity=ToolIdentity( + author="test_author", + name="test_tool", + label=I18nObject(en_US="Test Tool"), + provider="test_provider", + ), + parameters=base_params or [], + description=ToolDescription( + human=I18nObject(en_US="Test description"), + llm="Test description for LLM", + ), + output_schema={}, + ) + mock_tool = Mock(spec=Tool) + mock_tool.entity = entity + mock_tool.get_runtime_parameters.return_value = runtime_params + mock_tool.fork_tool_runtime.return_value = mock_tool + return mock_tool + + +def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None): + return ToolParameter( + name=name, + label=I18nObject(en_US=label or name), + human_description=I18nObject(en_US=name), + type=ToolParameter.ToolParameterType.STRING, + form=form, + ) + + +class TestConvertToolEntityToApiEntity: + """Tests for ToolTransformService.convert_tool_entity_to_api_entity.""" + + def test_parameter_override(self): + base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")] + runtime = [_param("param1", label="Runtime 1")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 2 + assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1" + assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2" + + def test_additional_runtime_parameters(self): + base = [_param("param1", label="Base 1")] + runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 2 + names = [p.name for p in result.parameters] + assert "param1" in names + assert "runtime_only" in names + + def test_non_form_runtime_parameters_excluded(self): + base = [_param("param1")] + runtime = [ + _param("param1", label="Runtime 1"), + _param("llm_param", form=ToolParameter.ToolParameterForm.LLM), + ] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 1 + assert result.parameters[0].name == "param1" + + def test_empty_parameters(self): + tool = _mock_tool(base_params=[], runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_none_parameters(self): + tool = _mock_tool(base_params=None, runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_parameter_order_preserved(self): + base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")] + runtime = [_param("p2", label="R2"), _param("p4", label="R4")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"] + assert result.parameters[1].label.en_US == "R2" + + +class TestWorkflowProviderToUserProvider: + """Tests for ToolTransformService.workflow_provider_to_user_provider.""" + + @staticmethod + def _make_controller(provider_id="provider_123", **identity_overrides): + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + defaults = { + "author": "test_author", + "name": "test_workflow_tool", + "description": I18nObject(en_US="Test description"), + "icon": '{"type": "emoji", "content": "🔧"}', + "icon_dark": None, + "label": I18nObject(en_US="Test Workflow Tool"), + } + defaults.update(identity_overrides) + identity = ToolProviderIdentity(**defaults) + entity = ToolProviderEntity(identity=identity) + return WorkflowToolProviderController(entity=entity, provider_id=provider_id) + + def test_with_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1", "l2"], + workflow_app_id="app_123", + ) + + assert isinstance(result, ToolProviderApiEntity) + assert result.id == "provider_123" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_123" + assert result.labels == ["l1", "l2"] + assert result.is_team_authorization is True + + def test_without_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1"], + ) + + assert result.workflow_app_id is None + + def test_workflow_app_id_none_explicit(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=None, + workflow_app_id=None, + ) + + assert result.workflow_app_id is None + assert result.labels == [] + + def test_preserves_other_fields(self): + ctrl = self._make_controller( + "provider_456", + author="another_author", + name="another_workflow_tool", + description=I18nObject(en_US="Another desc", zh_Hans="Another desc"), + icon='{"type": "emoji", "content": "⚙️"}', + icon_dark='{"type": "emoji", "content": "🔧"}', + label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"), + ) + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["automation"], + workflow_app_id="app_456", + ) + + assert result.id == "provider_456" + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_456" + assert result.is_team_authorization is True + assert result.allow_delete is True diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..21a1975879 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index c3fe6a2950..ce5c2bd162 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy.orm import Session from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, DatasetEntity, DatasetRetrieveConfigEntity, ExternalDataVariableEntity, @@ -13,10 +18,11 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant -from models.api_based_extension import APIBasedExtension +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow.workflow_converter import WorkflowConverter @@ -548,3 +554,198 @@ class TestWorkflowConverter: # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT), + VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH), + VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT), + ] + + +class TestConvertToHttpRequestNodeVariants: + """Tests for chatbot vs workflow differences in HTTP request node conversion.""" + + @staticmethod + def _setup(app_mode, default_variables): + app_model = App( + tenant_id="tenant_id", + mode=app_mode, + name="test", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + ) + + ext = APIBasedExtension(tenant_id="tenant_id", name="api-1", api_key="enc", api_endpoint="https://dify.ai") + ext.id = "ext_id" + + converter = WorkflowConverter() + converter._get_api_based_extension = MagicMock(return_value=ext) + + from core.helper import encrypter + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + ext_vars = [ + ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": "ext_id"} + ) + ] + nodes, _ = converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=ext_vars, + ) + return nodes + + def test_chatbot_query_uses_sys_query(self, default_variables): + nodes = self._setup(AppMode.CHAT, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "{{#sys.query#}}" + assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY + assert nodes[1]["data"]["type"] == "code" + + def test_workflow_query_is_empty(self, default_variables): + nodes = self._setup(AppMode.WORKFLOW, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "" + + +class TestConvertToKnowledgeRetrievalNodeVariants: + """Tests for chatbot vs workflow differences in knowledge retrieval node.""" + + @staticmethod + def _dataset_config(query_variable=None): + return DatasetEntity( + dataset_ids=["ds1", "ds2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + @staticmethod + def _model_config(): + return ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + def test_chatbot_uses_sys_query(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.ADVANCED_CHAT, + dataset_config=self._dataset_config(), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["sys", "query"] + + def test_workflow_uses_start_variable(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=self._dataset_config(query_variable="query"), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["start", "query"] + + +class TestConvertToLlmNode: + """Tests for LLM node conversion across model modes and prompt types.""" + + @staticmethod + def _model_config(model, mode): + return ModelConfigEntity( + provider="openai", + model=model, + mode=mode.value, + parameters={}, + stop=[], + ) + + @staticmethod + def _graph(default_variables): + start = WorkflowConverter()._convert_to_start_node(default_variables) + return {"nodes": [start], "edges": []} + + def test_simple_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["type"] == "llm" + assert node["data"]["model"]["mode"] == LLMMode.CHAT.value + assert node["data"]["context"]["enabled"] is False + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"][0]["text"] == expected + + def test_simple_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["model"]["mode"] == LLMMode.COMPLETION.value + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"]["text"] == expected + + def test_advanced_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are helpful named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], list) + assert len(node["data"]["prompt_template"]) == 3 + + def test_advanced_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are helpful named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", assistant="Assistant" + ), + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], dict) + assert "text" in node["data"]["prompt_template"] diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py new file mode 100644 index 0000000000..29e1e240b4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -0,0 +1,158 @@ +"""Testcontainers integration tests for WorkflowService.delete_workflow.""" + +import json +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + + +class TestWorkflowDeletion: + def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + session.add(tenant) + session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"wf_del_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + session.add(join) + session.flush() + return tenant, account + + def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App: + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + workflow_id=workflow_id, + ) + session.add(app) + session.flush() + return app + + def _create_workflow( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0" + ) -> Workflow: + workflow = Workflow( + id=str(uuid4()), + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + session.add(workflow) + session.flush() + return workflow + + def _create_tool_provider( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str + ) -> WorkflowToolProvider: + provider = WorkflowToolProvider( + name=f"tool-{uuid4()}", + label=f"Tool {uuid4()}", + icon="wrench", + app_id=app.id, + version=version, + user_id=account.id, + tenant_id=tenant.id, + description="test tool provider", + ) + session.add(provider) + session.flush() + return provider + + def test_delete_workflow_success(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + db_session_with_containers.commit() + workflow_id = workflow.id + + service = WorkflowService(sessionmaker(bind=db.engine)) + result = service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id + ) + + assert result is True + db_session_with_containers.expire_all() + assert db_session_with_containers.get(Workflow, workflow_id) is None + + def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="draft" + ) + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(DraftWorkflowDeletionError): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + # Point app to this workflow + app.workflow_id = workflow.id + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="currently in use by app"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0") + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="published as a tool"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index af9e8d0b2c..4dab895135 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -4,7 +4,7 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 210d9eb39e..6cbbe43137 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098..f9ae33b32f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -53,7 +54,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns @@ -141,7 +145,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, @@ -179,7 +183,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -221,17 +225,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -264,7 +268,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -451,7 +455,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -467,7 +471,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -483,7 +487,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -655,7 +659,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1cd698b870..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -153,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -192,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), @@ -869,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69..926c839c8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c320..9f8e37fc9e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = ["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b..d457b59d58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb4749..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a54..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388c..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -63,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b105..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -244,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f82..d341c5ce99 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,17 +9,17 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 5bded4d670..9a7507a2f9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,9 +4,9 @@ from unittest.mock import ANY, call, patch import pytest from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index ca76fa0a4b..b9f513a6d0 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -27,9 +27,9 @@ import pytest from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 4ea8d8c1c7..8854ef5e04 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -23,7 +23,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 3f75fd2851..55873b06a8 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine): def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): """ - Helper to set up the mock DB query chain for tenant/account authentication. + Helper to set up the mock DB execute chain for tenant/account authentication. - This configures the mock to return (tenant, account) for the join query used - by validate_app_token and validate_dataset_token decorators. + This configures the mock to return (tenant, account) for the + db.session.execute(select(...).join().join().where()).one_or_none() + query used by validate_app_token decorator. Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return mock_account: Mock account object to return """ - query = mock_db.session.query.return_value - join_chain = query.join.return_value.join.return_value - where_chain = join_chain.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): """ - Helper to set up the mock DB query chain for dataset tenant authentication. + Helper to set up the mock DB execute chain for dataset tenant authentication. - This configures the mock to return (tenant, tenant_account) for the where chain + This configures the mock to return (tenant, tenant_account) for the + db.session.execute(select(...).where().where().where().where()).one_or_none() query used by validate_dataset_token decorator. Args: @@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): mock_tenant: Mock tenant object to return mock_ta: Mock tenant account object to return """ - query = mock_db.session.query.return_value - where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" @@ -281,12 +328,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +350,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 021e9a0784..2d218dac7e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -20,7 +20,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py deleted file mode 100644 index 3ffa53b6db..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_message.py +++ /dev/null @@ -1,320 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.exceptions import InternalServerError, NotFound -from werkzeug.local import LocalProxy - -from controllers.console.app.error import ( - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.app.message import ( - ChatMessageListApi, - ChatMessagesQuery, - FeedbackExportQuery, - MessageAnnotationCountApi, - MessageApi, - MessageFeedbackApi, - MessageFeedbackExportApi, - MessageFeedbackPayload, - MessageSuggestedQuestionApi, -) -from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from models import App, AppMode -from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -import contextlib - - -@contextlib.contextmanager -def setup_test_context( - test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None -): - with ( - patch("extensions.ext_database.db") as mock_db, - patch("controllers.console.app.wraps.db", mock_db), - patch("controllers.console.wraps.db", mock_db), - patch("controllers.console.app.message.db", mock_db), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - # Set up a generic query mock that usually returns mock_app_model when getting app - app_query_mock = MagicMock() - app_query_mock.filter.return_value.first.return_value = mock_app_model - app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model - - data_query_mock = MagicMock() - - def query_side_effect(*args, **kwargs): - if args and hasattr(args[0], "__name__") and args[0].__name__ == "App": - return app_query_mock - return data_query_mock - - mock_db.session.query.side_effect = query_side_effect - mock_db.data_query = data_query_mock - - # Let the caller override the stat db query logic - proxy_mock = LocalProxy(lambda: mock_account) - - query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) - full_path = f"{route_path}?{query_string}" if qs else route_path - - with ( - patch("libs.login.current_user", proxy_mock), - patch("flask_login.current_user", proxy_mock), - patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), - ): - with test_app.test_request_context(full_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - - if "suggested-questions" in route_path: - # simplistic extraction for message_id - parts = route_path.split("chat-messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - elif "messages/" in route_path and "chat-messages" not in route_path: - parts = route_path.split("messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - - # Check if it has a dispatch_request or method - if hasattr(api_instance, method.lower()): - yield api_instance, mock_db, request.view_args - - -class TestMessageValidators: - def test_chat_messages_query_validators(self): - # Test empty_to_none - assert ChatMessagesQuery.empty_to_none("") is None - assert ChatMessagesQuery.empty_to_none("val") == "val" - - # Test validate_uuid - assert ChatMessagesQuery.validate_uuid(None) is None - assert ( - ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_message_feedback_validators(self): - assert ( - MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_feedback_export_validators(self): - assert FeedbackExportQuery.parse_bool(None) is None - assert FeedbackExportQuery.parse_bool(True) is True - assert FeedbackExportQuery.parse_bool("1") is True - assert FeedbackExportQuery.parse_bool("0") is False - assert FeedbackExportQuery.parse_bool("off") is False - - with pytest.raises(ValueError): - FeedbackExportQuery.parse_bool("invalid") - - -class TestMessageEndpoints: - def test_chat_message_list_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.first.return_value = None - - with pytest.raises(NotFound): - api.get(**v_args) - - def test_chat_message_list_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1}, - ) as (api, mock_db, v_args): - mock_conv = MagicMock() - mock_conv.id = "123e4567-e89b-12d3-a456-426614174000" - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - # mock returns - q_mock = mock_db.data_query - q_mock.where.return_value.first.side_effect = [mock_conv] - q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg] - mock_db.session.scalar.return_value = False - - resp = api.get(**v_args) - assert resp["limit"] == 1 - assert resp["has_more"] is False - assert len(resp["data"]) == 1 - - def test_message_feedback_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - MessageFeedbackApi, - "/apps/app_123/feedbacks", - "POST", - mock_account, - mock_app_model, - payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.first.return_value = None - - with pytest.raises(NotFound): - api.post(**v_args) - - def test_message_feedback_success(self, app, mock_account, mock_app_model): - payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"} - with setup_test_context( - app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.admin_feedback = None - mock_db.data_query.where.return_value.first.return_value = mock_msg - - resp = api.post(**v_args) - assert resp == {"result": "success"} - - def test_message_annotation_count(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.count.return_value = 5 - - resp = api.get(**v_args) - assert resp == {"count": 5} - - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model): - mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"] - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"data": ["q1", "q2"]} - - @pytest.mark.parametrize( - ("exc", "expected_exc"), - [ - (MessageNotExistsError, NotFound), - (ConversationNotExistsError, NotFound), - (ProviderTokenNotInitError, ProviderNotInitializeError), - (QuotaExceededError, ProviderQuotaExceededError), - (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError), - (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError), - (Exception, InternalServerError), - ], - ) - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_errors( - self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model - ): - mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc() - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - with pytest.raises(expected_exc): - api.get(**v_args) - - @patch("services.feedback_service.FeedbackService.export_feedbacks") - def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model): - mock_export.return_value = {"exported": True} - - with setup_test_context( - app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"exported": True} - - def test_message_api_get_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - mock_db.data_query.where.return_value.first.return_value = mock_msg - - resp = api.get(**v_args) - assert resp["id"] == "msg_123" diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py deleted file mode 100644 index beba23385d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_statistic.py +++ /dev/null @@ -1,275 +0,0 @@ -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.statistic import ( - AverageResponseTimeStatistic, - AverageSessionInteractionStatistic, - DailyConversationStatistic, - DailyMessageStatistic, - DailyTerminalsStatistic, - DailyTokenCostStatistic, - TokensPerSecondStatistic, - UserSatisfactionRateStatistic, -) -from models import App, AppMode - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context( - test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) -): - with ( - patch("controllers.console.app.statistic.db") as mock_db_stat, - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch( - "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") - ), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_conn = MagicMock() - mock_conn.execute.return_value = mock_rs - - mock_begin = MagicMock() - mock_begin.__enter__.return_value = mock_conn - mock_db_stat.engine.begin.return_value = mock_begin - - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method="GET"): - request.view_args = {"app_id": "app_123"} - api_instance = endpoint_class() - response = api_instance.get(app_id="app_123") - return response - - -class TestStatisticEndpoints: - def test_daily_message_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["message_count"] == 10 - - def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.conversation_count = 5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyConversationStatistic, - "/apps/app_123/statistics/daily-conversations", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["conversation_count"] == 5 - - def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.terminal_count = 2 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTerminalsStatistic, - "/apps/app_123/statistics/daily-end-users", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["terminal_count"] == 2 - - def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.token_count = 100 - mock_row.total_price = Decimal("0.02") - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTokenCostStatistic, - "/apps/app_123/statistics/token-costs", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["token_count"] == 100 - assert response.json["data"][0]["total_price"] == "0.02" - - def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.interactions = Decimal("3.523") - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageSessionInteractionStatistic, - "/apps/app_123/statistics/average-session-interactions", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["interactions"] == 3.52 - - def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 100 - mock_row.feedback_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - UserSatisfactionRateStatistic, - "/apps/app_123/statistics/user-satisfaction-rate", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["rate"] == 100.0 - - def test_average_response_time_statistic(self, app, mock_account, mock_app_model): - mock_app_model.mode = AppMode.COMPLETION - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.latency = 1.234 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageResponseTimeStatistic, - "/apps/app_123/statistics/average-response-time", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["latency"] == 1234.0 - - def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.tokens_per_second = 15.5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - TokensPerSecondStatistic, - "/apps/app_123/statistics/tokens-per-second", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["tps"] == 15.5 - - @patch("controllers.console.app.statistic.parse_time_range") - def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): - mock_parse.side_effect = ValueError("Invalid time") - - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", - mock_account, - mock_app_model, - [], - ) - - @patch("controllers.console.app.statistic.parse_time_range") - def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): - import datetime - - start = datetime.datetime.now() - end = datetime.datetime.now() - mock_parse.return_value = (start, end) - - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=something&end=something", - mock_account, - mock_app_model, - [], - ) - assert response.status_code == 200 - mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 0e22db9f9b..341efc05ca 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -9,8 +9,8 @@ from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File def _unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 9b5d47c208..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,313 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.error import DraftWorkflowNotExist -from controllers.console.app.workflow_draft_variable import ( - ConversationVariableCollectionApi, - EnvironmentVariableCollectionApi, - NodeVariableCollectionApi, - SystemVariableCollectionApi, - VariableApi, - VariableResetApi, - WorkflowVariableCollectionApi, -) -from controllers.web.error import InvalidArgumentError, NotFoundError -from models import App, AppMode -from models.enums import DraftVariableType - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.WORKFLOW - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None): - with ( - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch("controllers.console.app.workflow_draft_variable.db"), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - # extract node_id or variable_id from path manually since view_args overrides - if "nodes/" in route_path: - request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0] - if "variables/" in route_path: - # simplistic extraction - parts = route_path.split("variables/") - if len(parts) > 1 and parts[1] and parts[1] != "reset": - request.view_args["variable_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - # we just call dispatch_request to avoid manual argument passing - if hasattr(api_instance, method.lower()): - func = getattr(api_instance, method.lower()) - return func(**request.view_args) - - -class TestWorkflowDraftVariableEndpoints: - @staticmethod - def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock: - class DummyValueType: - def exposed_type(self): - return DraftVariableType.NODE - - mock_var = MagicMock() - mock_var.app_id = "app_123" - mock_var.id = "var_123" - mock_var.name = "test_var" - mock_var.description = "" - mock_var.get_variable_type.return_value = variable_type - mock_var.get_selector.return_value = [] - mock_var.value_type = DummyValueType() - mock_var.edited = False - mock_var.visible = True - mock_var.file_id = None - mock_var.variable_file = None - mock_var.is_truncated.return_value = False - mock_var.get_value.return_value.model_copy.return_value.value = "test_value" - return mock_var - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_get_success( - self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model - ): - mock_wf_srv.return_value.is_workflow_exist.return_value = True - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList( - variables=[], total=0 - ) - - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables?page=1&limit=20", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": [], "total": 0} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.is_workflow_exist.return_value = False - - with pytest.raises(DraftWorkflowNotExist): - setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[]) - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model): - with pytest.raises(InvalidArgumentError): - setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/sys/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - assert resp["id"] == "var_123" - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = None - - with pytest.raises(NotFoundError): - setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, - VariableApi, - "/apps/app_123/workflows/draft/variables/var_123", - "PATCH", - mock_account, - mock_app_model, - payload={"name": "new_name"}, - ) - assert resp["id"] == "var_123" - mock_draft_srv.return_value.update_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model - ) - assert resp.status_code == 204 - mock_draft_srv.return_value.delete_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - mock_draft_srv.return_value.reset_variable.return_value = None # means no content - - resp = setup_test_context( - app, - VariableResetApi, - "/apps/app_123/workflows/draft/variables/var_123/reset", - "PUT", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - ConversationVariableCollectionApi, - "/apps/app_123/workflows/draft/conversation-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - SystemVariableCollectionApi, - "/apps/app_123/workflows/draft/system-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf = MagicMock() - mock_wf.environment_variables = [] - mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf - - resp = setup_test_context( - app, - EnvironmentVariableCollectionApi, - "/apps/app_123/workflows/draft/environment-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b9..c4a8148446 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -10,10 +10,10 @@ from flask import Flask from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a257..559b5fea09 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,9 +13,9 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -310,13 +310,12 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file.enums import FileTransferMethod, FileType + from graphon.file.models import File # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -368,13 +367,12 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file.enums import FileTransferMethod, FileType + from graphon.file.models import File # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py deleted file mode 100644 index bc4c7e0993..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ /dev/null @@ -1,209 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask - -from controllers.console.auth.data_source_bearer_auth import ( - ApiKeyAuthDataSource, - ApiKeyAuthDataSourceBinding, - ApiKeyAuthDataSourceBindingDelete, -) -from controllers.console.auth.error import ApiKeyAuthFailedError - - -class TestApiKeyAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_binding = MagicMock() - mock_binding.id = "bind_123" - mock_binding.category = "api_key" - mock_binding.provider = "custom_provider" - mock_binding.disabled = False - mock_binding.created_at.timestamp.return_value = 1620000000 - mock_binding.updated_at.timestamp.return_value = 1620000001 - - mock_get_list.return_value = [mock_binding] - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 1 - assert response["sources"][0]["provider"] == "custom_provider" - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_get_list.return_value = None - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 0 - - -class TestApiKeyAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - response = api_instance.post() - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_validate.assert_called_once() - mock_create.assert_called_once() - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_create.side_effect = ValueError("Invalid structure") - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): - api_instance.post() - - -class TestApiKeyAuthDataSourceBindingDelete: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") - def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBindingDelete() - response = api_instance.delete("binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 204 - mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py deleted file mode 100644 index f369565946..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py +++ /dev/null @@ -1,192 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.local import LocalProxy - -from controllers.console.auth.data_source_oauth import ( - OAuthDataSource, - OAuthDataSourceBinding, - OAuthDataSourceCallback, - OAuthDataSourceSync, -) - - -class TestOAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) - def test_get_oauth_url_successful( - self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app - ): - mock_oauth_provider = MagicMock() - mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" - mock_get_providers.return_value = {"notion": mock_oauth_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - mock_libs_user.return_value = mock_account - mock_flask_user.return_value = mock_account - - # also patch current_account_with_tenant - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("notion") - - assert response[0]["data"] == "http://oauth.provider/auth" - assert response[1] == 200 - mock_oauth_provider.get_authorization_url.assert_called_once() - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("unknown_provider") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceCallback: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_successful(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "code=mock_code" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_missing_code(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "error=Access denied" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_invalid_provider(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("invalid") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_successful(self, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.get_access_token.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.get_access_token.assert_called_once_with("auth_code_123") - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_missing_code(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["error"] == "Invalid code" - assert response[1] == 400 - - -class TestOAuthDataSourceSync: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.sync_data_source.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSourceSync() - # The route pattern uses , so we just pass a string for unit testing - response = api_instance.get("notion", "binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.sync_data_source.assert_called_once_with("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py deleted file mode 100644 index fc5663e72d..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ /dev/null @@ -1,417 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.exceptions import BadRequest, NotFound - -from controllers.console.auth.oauth_server import ( - OAuthServerAppApi, - OAuthServerUserAccountApi, - OAuthServerUserAuthorizeApi, - OAuthServerUserTokenApi, -) - - -class TestOAuthServerAppApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.redirect_uris = ["http://localhost/callback"] - oauth_app.app_icon = "icon_url" - oauth_app.app_label = "Test App" - oauth_app.scope = "read,write" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - response = api_instance.post() - - assert response["app_icon"] == "icon_url" - assert response["app_label"] == "Test App" - assert response["scope"] == "read,write" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_client_id(self, mock_get_app, mock_db, app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = None - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(NotFound, match="client_id is invalid"): - api_instance.post() - - -class TestOAuthServerUserAuthorizeApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - oauth_app = MagicMock() - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.current_account_with_tenant") - @patch("controllers.console.wraps.current_account_with_tenant") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") - @patch("libs.login.check_csrf_token") - def test_successful_authorize( - self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app - ): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.id = "user_123" - from models.account import AccountStatus - - mock_account.status = AccountStatus.ACTIVE - - mock_current.return_value = (mock_account, MagicMock()) - mock_wrap_current.return_value = (mock_account, MagicMock()) - - mock_sign.return_value = "auth_code_123" - - with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): - with patch("libs.login.current_user", mock_account): - api_instance = OAuthServerUserAuthorizeApi() - response = api_instance.post() - - assert response["code"] == "auth_code_123" - mock_sign.assert_called_once_with("test_client_id", "user_123") - - -class TestOAuthServerUserTokenApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.client_secret = "test_secret" - oauth_app.redirect_uris = ["http://localhost/callback"] - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("access_123", "refresh_123") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "access_123" - assert response["refresh_token"] == "refresh_123" - assert response["token_type"] == "Bearer" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="code is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "invalid_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="client_secret is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://invalid/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("new_access", "new_refresh") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "new_access" - assert response["refresh_token"] == "new_refresh" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "refresh_token", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="refresh_token is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "invalid_grant", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="invalid grant_type"): - api_instance.post() - - -class TestOAuthServerUserAccountApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.name = "Test User" - mock_account.email = "test@example.com" - mock_account.avatar = "avatar_url" - mock_account.interface_language = "en-US" - mock_account.timezone = "UTC" - mock_validate.return_value = mock_account - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer valid_access_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response["name"] == "Test User" - assert response["email"] == "test@example.com" - assert response["avatar"] == "avatar_url" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Authorization header is required" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "InvalidFormat"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Basic something"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "token_type is invalid" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer "}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_validate.return_value = None - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer invalid_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "access_token or client_id is invalid" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9014edc39e..5136922e88 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -17,7 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63..63950736c5 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -14,8 +14,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 472d133349..aa7c3c7fbd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -26,6 +26,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineWorkflowLastRunApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from libs.datetime_utils import naive_utc_now from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -372,7 +373,7 @@ class TestPublishedPipelineApis: workflow = MagicMock( id="w1", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) session = MagicMock() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..d841f67f9b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504b..8555900f4e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile @@ -416,7 +417,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -520,7 +521,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -579,7 +580,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } @@ -1475,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1525,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1543,8 +1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1590,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1624,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1652,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1680,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f23dd5b44a..ce2278de4f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -139,8 +140,8 @@ class TestDatasetDocumentListApi: return_value=pagination, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=2, ), patch( "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", @@ -699,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=log, ), ): response, status = method(api, "ds-1", "doc-1") @@ -765,8 +764,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,19 +821,16 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) - query_mock = MagicMock() - query_mock.where.return_value.first.return_value = None - with ( app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -849,7 +845,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -862,10 +858,8 @@ class TestDocumentIndexingEstimateApi: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", @@ -973,7 +967,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +995,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1018,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1238,12 +1232,8 @@ class TestDocumentPermissionCases: return_value=None, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=lambda *a: MagicMock( - order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) - ) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=process_rule, ), ): result = method(api) @@ -1353,7 +1343,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -1363,8 +1353,8 @@ class TestDocumentIndexingEdgeCases: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad9..693b06e95b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -1,4 +1,3 @@ -from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -24,6 +23,8 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType +from libs.datetime_utils import naive_utc_now from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -53,8 +54,8 @@ def _segment(): disabled_by=None, status="normal", created_by="u1", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), + created_at=naive_utc_now(), + updated_at=naive_utc_now(), updated_by="u1", indexing_at=None, completed_at=None, @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() @@ -525,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -620,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -705,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -737,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -769,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -830,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -879,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -923,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -969,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1179,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1214,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e7ae37ae45..e4acd91b76 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -20,7 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py index 90f00711c1..e358435de4 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -26,12 +26,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = None - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=None, ) with pytest.raises(PipelineNotFoundError): @@ -51,12 +48,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -76,12 +70,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -100,18 +91,15 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - def where_side_effect(*args, **kwargs): - assert args[0].right.value == "123" - return Mock(first=lambda: pipeline) - - mock_query = Mock() - mock_query.where.side_effect = where_side_effect - - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + mock_scalar = mocker.patch( + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id=123) assert result is pipeline + # Verify the pipeline_id was cast to string in the where clause + stmt = mock_scalar.call_args[0][0] + where_clauses = stmt.whereclause.clauses + assert where_clauses[0].right.value == "123" diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 0afbc5a8f7..b4b57022e2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -19,7 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 6b5c304884..145cc9cdd7 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -21,7 +21,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 5a03daecbc..03eadcdb4e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 769edc8d1c..e89b89c8b1 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -11,6 +11,7 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models.enums import TagType def unwrap(func): @@ -52,7 +53,7 @@ def tag(): tag = MagicMock() tag.id = "tag-1" tag.name = "test-tag" - tag.type = "knowledge" + tag.type = TagType.KNOWLEDGE return tag diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index c18dd044a7..2dff9c4037 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index f2e57eb65f..b2f949c6e2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -13,8 +13,8 @@ from flask import Flask from flask.views import MethodView from werkzeug.exceptions import Forbidden -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index af0c2c5594..168479af1e 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -13,7 +13,7 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 43b8e1ac2e..f0d32f81fb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -14,8 +14,8 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index eb19243225..ce5fd1c466 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -90,8 +90,8 @@ class TestPluginListLatestVersionsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDebuggingKeyApi: @@ -120,8 +120,8 @@ class TestPluginDebuggingKeyApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginListApi: @@ -202,8 +202,9 @@ class TestPluginUploadFromPkgApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_pkg_mock.assert_not_called() @@ -365,8 +366,8 @@ class TestPluginListInstallationsFromIdsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromGithubApi: @@ -401,8 +402,8 @@ class TestPluginUploadFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromBundleApi: @@ -449,8 +450,9 @@ class TestPluginUploadFromBundleApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_bundle_mock.assert_not_called() @@ -495,8 +497,8 @@ class TestPluginInstallFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginInstallFromMarketplaceApi: @@ -532,8 +534,8 @@ class TestPluginInstallFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchMarketplacePkgApi: @@ -562,8 +564,8 @@ class TestPluginFetchMarketplacePkgApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchManifestApi: @@ -595,8 +597,8 @@ class TestPluginFetchManifestApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTasksApi: @@ -625,8 +627,8 @@ class TestPluginFetchInstallTasksApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTaskApi: @@ -655,8 +657,8 @@ class TestPluginFetchInstallTaskApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskApi: @@ -685,8 +687,8 @@ class TestPluginDeleteInstallTaskApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteAllInstallTaskItemsApi: @@ -717,8 +719,8 @@ class TestPluginDeleteAllInstallTaskItemsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskItemApi: @@ -747,8 +749,8 @@ class TestPluginDeleteInstallTaskItemApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "task1", "item1") + result = method(api, "task1", "item1") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromMarketplaceApi: @@ -790,8 +792,8 @@ class TestPluginUpgradeFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromGithubApi: @@ -839,8 +841,8 @@ class TestPluginUpgradeFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: @@ -894,8 +896,8 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginChangePreferencesApi: diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 94c3019d5e..44feacf2ad 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -4,7 +4,7 @@ from __future__ import annotations import builtins import importlib -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"): _CONTROLLER_MODULE: ModuleType | None = None _WRAPS_MODULE: ModuleType | None = None -_CONTROLLER_PATCHERS: list[patch] = [] @contextmanager @@ -37,6 +36,14 @@ def app() -> Flask: @pytest.fixture def controller_module(monkeypatch: pytest.MonkeyPatch): + """ + Import the controller with auth decorators neutralized only during import. + + The imported view classes retain those no-op decorators after import, so we + can restore the original globals immediately and avoid leaking auth patches + into unrelated tests such as libs.login unit coverage. + """ + module_name = "controllers.console.workspace.tool_providers" global _CONTROLLER_MODULE if _CONTROLLER_MODULE is None: @@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): ("controllers.console.wraps.is_admin_or_owner_required", _noop), ("controllers.console.wraps.enterprise_license_required", _noop), ] - for target, value in patch_targets: - patcher = patch(target, value) - patcher.start() - _CONTROLLER_PATCHERS.append(patcher) monkeypatch.setenv("DIFY_SETUP_READY", "true") - with _mock_db(): - _CONTROLLER_MODULE = importlib.import_module(module_name) + with ExitStack() as stack: + for target, value in patch_targets: + stack.enter_context(patch(target, value)) + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) module = _CONTROLLER_MODULE monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index f5ebe0b534..b2d13dbbdf 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -1,4 +1,3 @@ -from datetime import datetime from io import BytesIO from unittest.mock import MagicMock, patch @@ -26,6 +25,7 @@ from controllers.console.workspace.workspace import ( WorkspacePermissionApi, ) from enums.cloud_plan import CloudPlan +from libs.datetime_utils import naive_utc_now from models.account import TenantStatus @@ -44,13 +44,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) with ( @@ -97,13 +97,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features_t2 = MagicMock() @@ -152,13 +152,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features = MagicMock() @@ -204,7 +204,7 @@ class TestTenantListApi: id="t1", name="Tenant", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features = MagicMock() @@ -243,13 +243,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) with ( @@ -305,7 +305,7 @@ class TestWorkspaceListApi: api = WorkspaceListApi() method = unwrap(api.get) - tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now()) paginate_result = MagicMock( items=[tenant], @@ -331,7 +331,7 @@ class TestWorkspaceListApi: id="t1", name="T", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) paginate_result = MagicMock( diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea..edb91c3f26 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size - self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/controllers/inner_api/app/__init__.py b/api/tests/unit_tests/controllers/inner_api/app/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py new file mode 100644 index 0000000000..5862239142 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -0,0 +1,245 @@ +"""Unit tests for inner_api app DSL import/export endpoints. + +Tests Pydantic model validation, endpoint handler logic, and the +_get_active_account helper. Auth/setup decorators are tested separately +in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.app.dsl import ( + EnterpriseAppDSLExport, + EnterpriseAppDSLImport, + InnerAppDSLImportPayload, + _get_active_account, +) +from services.app_dsl_service import ImportStatus + + +class TestInnerAppDSLImportPayload: + """Test InnerAppDSLImportPayload Pydantic model validation.""" + + def test_valid_payload_all_fields(self): + data = { + "yaml_content": "version: 0.6.0\nkind: app\n", + "creator_email": "user@example.com", + "name": "My App", + "description": "A test app", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.yaml_content == data["yaml_content"] + assert payload.creator_email == "user@example.com" + assert payload.name == "My App" + assert payload.description == "A test app" + + def test_valid_payload_optional_fields_omitted(self): + data = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.name is None + assert payload.description is None + + def test_missing_yaml_content_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"}) + assert "yaml_content" in str(exc_info.value) + + def test_missing_creator_email_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"yaml_content": "test"}) + assert "creator_email" in str(exc_info.value) + + +class TestGetActiveAccount: + """Test the _get_active_account helper function.""" + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_active_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "active" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + result = _get_active_account("user@example.com") + + assert result is mock_account + mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com") + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_inactive_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "banned" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + result = _get_active_account("banned@example.com") + + assert result is None + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_nonexistent_email(self, mock_db): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + result = _get_active_account("missing@example.com") + + assert result is None + + +class TestEnterpriseAppDSLImport: + """Test EnterpriseAppDSLImport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLImport() + + @pytest.fixture + def _mock_import_deps(self): + """Patch db, Session, and AppDslService for import handler tests.""" + with ( + patch("controllers.inner_api.app.dsl.db"), + patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, + ): + mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_dsl = MagicMock() + mock_dsl_cls.return_value = self._mock_dsl + yield + + def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import": + from services.app_dsl_service import Import + + result = Import( + id="import-id", + status=status, + app_id=kwargs.get("app_id", "app-123"), + app_mode=kwargs.get("app_mode", "workflow"), + ) + return result + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.COMPLETED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 200 + assert body["status"] == "completed" + mock_account.set_tenant_id.assert_called_once_with("ws-123") + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 202 + assert body["status"] == "pending" + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 400 + assert body["status"] == "failed" + + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = None + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"} + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 404 + assert "missing@e.com" in body["message"] + + +class TestEnterpriseAppDSLExport: + """Test EnterpriseAppDSLExport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLExport() + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + assert body["data"] == "version: 0.6.0\nkind: app\n" + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False) + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "yaml-data" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=true"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True) + + @patch("controllers.inner_api.app.dsl.db") + def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="nonexistent") + + body, status_code = result + assert status_code == 404 + assert "app not found" in body["message"] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f8e9cf9b80..1507bf7a5f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -65,7 +65,7 @@ class TestAppParameterApi: mock_tenant.status = "normal" # Mock DB queries for app and tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -112,7 +112,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -153,7 +153,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -192,7 +192,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -255,7 +255,7 @@ class TestAppMetaApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -323,7 +323,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -380,7 +380,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -426,7 +426,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -478,7 +478,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1923ab7fa7..e81e612803 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -29,7 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 4e4482f704..3364c07e62 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -34,7 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 1bdcd0f1a3..d83c22f2cf 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -79,10 +79,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -100,8 +103,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + # Mock MessageFile not found via scalar() + mock_db.session.scalar.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -115,8 +118,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile found but Message not owned by app - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock MessageFile found but Message not owned by app via scalar() + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found None, # Message query - not found (access denied) ] @@ -133,12 +136,13 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile and Message found but UploadFile not found - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found mock_message, # Message query - found - None, # UploadFile query - not found ] + # Mock get() for UploadFile - not found + mock_db.session.get.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -161,10 +165,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -262,10 +269,13 @@ class TestFilePreviewApi: mock_storage.load.return_value = mock_generator with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -301,10 +311,13 @@ class TestFilePreviewApi: mock_storage.load.side_effect = Exception("Storage error") with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries for validation - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -327,8 +340,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database query to raise unexpected exception - mock_db.session.query.side_effect = Exception("Unexpected database error") + # Mock database scalar to raise unexpected exception + mock_db.session.scalar.side_effect = Exception("Unexpected database error") # Execute and assert exception with pytest.raises(FileAccessDeniedError) as exc_info: diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 4eada73b82..6543c27037 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -35,7 +35,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from dify_graph.graph_engine.manager import GraphEngineManager + from graphon.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 9e95f45a0a..eda270258d 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,7 @@ from types import SimpleNamespace from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 4337a0c8c0..eddba5a517 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -118,11 +119,8 @@ class AuthenticationMocker: @staticmethod def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None): - """Configure mock_db to return app and tenant in sequence.""" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - ] + """Configure mock_db to return app and tenant via session.get().""" + mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: mock_ta = Mock() @@ -135,11 +133,9 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_query = mock_db.session.query.return_value - target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value - target_mock.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + mock_db.session.get.return_value = mock_account @pytest.fixture @@ -175,7 +171,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050c..910d781cd0 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import ( from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.tag_service import TagService @@ -277,7 +278,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Test Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag_service.get_tags.return_value = [mock_tag] @@ -316,7 +317,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "new_tag_1" mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag_service.save_tags.return_value = mock_tag mock_service_api_ns.payload = {"name": "New Tag"} @@ -378,7 +379,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Updated Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "5" mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.get_tag_binding_count.return_value = 5 @@ -866,7 +867,7 @@ class TestTagService: mock_tag = Mock() mock_tag.id = str(uuid.uuid4()) mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_save.return_value = mock_tag result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) @@ -941,11 +942,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. + ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1043,12 +1044,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. + ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef1804..7f5d6b0839 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -787,8 +788,8 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_db.session.scalar.return_value = mock_dataset + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -812,7 +813,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -832,7 +833,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -898,12 +899,12 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None @@ -949,7 +950,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -991,7 +992,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1042,7 +1043,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1086,12 +1087,12 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment @@ -1279,7 +1280,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1320,7 +1321,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1369,9 +1370,9 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} @@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") @@ -1404,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1435,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1470,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1514,7 +1515,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1553,7 +1554,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1582,7 +1583,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1614,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1675,7 +1676,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1716,7 +1717,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1754,7 +1755,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1807,7 +1808,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1857,7 +1858,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1898,7 +1899,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1938,7 +1939,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be19..12d5e7345d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} @@ -706,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ -735,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -756,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -777,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -798,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -827,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -849,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -866,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -887,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -904,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -975,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1024,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1053,7 +1062,7 @@ class TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1139,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1182,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1221,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1252,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1287,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1317,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1355,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1391,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1439,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py index b58caf3be1..c0b40d070a 100644 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ b/api/tests/unit_tests/controllers/service_api/test_site.py @@ -88,7 +88,7 @@ class TestAppSiteApi: mock_app_model.tenant = mock_tenant # Mock wraps.db for authentication - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -98,7 +98,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site.db for site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -109,7 +109,7 @@ class TestAppSiteApi: assert response["title"] == "Test Site" assert response["icon"] == "icon-url" assert response["description"] == "Site description" - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() @patch("controllers.service_api.wraps.user_logged_in") @patch("controllers.service_api.app.site.db") @@ -140,7 +140,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -150,7 +150,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query to return None - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -187,7 +187,7 @@ class TestAppSiteApi: mock_tenant = Mock() mock_tenant.status = TenantStatus.NORMAL - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -197,7 +197,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Set tenant status to archived AFTER authentication mock_app_model.tenant.status = TenantStatus.ARCHIVE @@ -230,7 +230,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -258,7 +258,7 @@ class TestAppSiteApi: mock_site.icon_type = "image" mock_site.created_at = "2024-01-01T00:00:00" mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -267,4 +267,4 @@ class TestAppSiteApi: # Assert # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 9c2d075f41..a2008e024b 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -144,14 +144,10 @@ class TestValidateAppToken: mock_ta = Mock() mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - mock_account, - ] + # Use side_effect to return app first, then tenant via session.get() + mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query + # Mock the tenant owner query (execute(select(...)).one_or_none()) setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) @validate_app_token @@ -175,7 +171,7 @@ class TestValidateAppToken: mock_api_token.app_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None @validate_app_token def protected_view(**kwargs): @@ -198,7 +194,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "abnormal" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -222,7 +218,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "normal" mock_app.enable_api = False - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -474,11 +470,11 @@ class TestValidateDatasetToken: mock_account.id = mock_ta.account_id mock_account.current_tenant = mock_tenant - # Mock the tenant account join query + # Mock the tenant account join query (execute(select(...)).one_or_none()) setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) - # Mock the account query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + # Mock the account lookup via session.get() + mock_db.session.get.return_value = mock_account @validate_dataset_token def protected_view(tenant_id): @@ -501,7 +497,7 @@ class TestValidateDatasetToken: mock_api_token.tenant_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None @validate_dataset_token def protected_view(dataset_id=None, **kwargs): diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index 01f34345aa..a6ca441801 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -21,7 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index e88bcf2ae6..4f8d848637 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -18,7 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index f6d1edbaf0..cde8820e00 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -6,7 +6,7 @@ import pytest from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): @@ -387,7 +387,7 @@ class TestRun: runner.update_prompt_message_tool.assert_called_once() def test_historic_with_assistant_and_tool_calls(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage assistant = AssistantPromptMessage(content="thinking") assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] @@ -400,7 +400,7 @@ class TestRun: assert isinstance(result, list) def test_historic_final_flush_branch(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage assistant = AssistantPromptMessage(content="final") runner.history_prompt_messages = [assistant] @@ -458,7 +458,7 @@ class TestFillInputsEdgeCases: class TestOrganizeHistoricPromptMessagesExtended: def test_user_message_flushes_scratchpad(self, runner, mocker): - from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + from graphon.model_runtime.entities.message_entities import UserPromptMessage user_message = UserPromptMessage(content="Hi") @@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended: assert result == ["final"] def test_tool_message_without_scratchpad_raises(self, runner): - from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + from graphon.model_runtime.entities.message_entities import ToolPromptMessage runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index f9d69d1196..ea8cc8aa86 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, @@ -93,7 +93,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", @@ -118,7 +118,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index ab822bb57d..2f5873d865 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -3,7 +3,7 @@ import json import pytest from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 299c9b31d2..17ab5babcb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -8,8 +8,8 @@ from core.agent.errors import AgentMaxIterationError from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/app/app_config/__init__.py b/api/tests/unit_tests/core/app/app_config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py new file mode 100644 index 0000000000..1c5b6ed944 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py @@ -0,0 +1,227 @@ +from unittest.mock import MagicMock + +import pytest + +# Module under test +from core.app.app_config.common import parameters_mapping + + +class TestGetParametersFromFeatureDict: + """Test suite for get_parameters_from_feature_dict""" + + @pytest.fixture + def mock_config(self, monkeypatch): + """Mock dify_config values""" + mock = MagicMock() + mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1 + mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2 + mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3 + mock.UPLOAD_FILE_SIZE_LIMIT = 4 + mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5 + + monkeypatch.setattr(parameters_mapping, "dify_config", mock) + return mock + + @pytest.fixture + def mock_default_file_limits(self, monkeypatch): + """Mock DEFAULT_FILE_NUMBER_LIMITS constant""" + monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99) + return 99 + + @pytest.fixture + def minimal_inputs(self): + return {}, [] + + @pytest.mark.parametrize( + ("feature_key", "expected_default"), + [ + ("suggested_questions", []), + ("suggested_questions_after_answer", {"enabled": False}), + ("speech_to_text", {"enabled": False}), + ("text_to_speech", {"enabled": False}), + ("retriever_resource", {"enabled": False}), + ("annotation_reply", {"enabled": False}), + ("more_like_this", {"enabled": False}), + ( + "sensitive_word_avoidance", + {"enabled": False, "type": "", "configs": []}, + ), + ], + ) + def test_defaults_when_key_missing( + self, + feature_key, + expected_default, + mock_config, + mock_default_file_limits, + ): + # Arrange + features = {} + user_input = [] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + assert result[feature_key] == expected_default + + def test_opening_statement_present(self, mock_config, mock_default_file_limits): + # Arrange + features = {"opening_statement": "Hello"} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] == "Hello" + + def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] is None + + def test_all_features_provided(self, mock_config, mock_default_file_limits): + # Arrange + features = { + "opening_statement": "Hi", + "suggested_questions": ["Q1"], + "suggested_questions_after_answer": {"enabled": True}, + "speech_to_text": {"enabled": True}, + "text_to_speech": {"enabled": True}, + "retriever_resource": {"enabled": True}, + "annotation_reply": {"enabled": True}, + "more_like_this": {"enabled": True}, + "sensitive_word_avoidance": { + "enabled": True, + "type": "strict", + "configs": ["a"], + }, + "file_upload": { + "image": { + "enabled": True, + "number_limits": 10, + "detail": "low", + "transfer_methods": ["local_file"], + } + }, + } + user_input = [{"name": "field1"}] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + for key in features: + assert result[key] == features[key] + assert result["user_input_form"] == user_input + + def test_file_upload_default_structure(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + file_upload = result["file_upload"] + assert file_upload["image"]["enabled"] is False + assert file_upload["image"]["number_limits"] == 99 + assert file_upload["image"]["detail"] == "high" + assert "remote_url" in file_upload["image"]["transfer_methods"] + assert "local_file" in file_upload["image"]["transfer_methods"] + + def test_system_parameters_from_config(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + system_params = result["system_parameters"] + assert system_params["image_file_size_limit"] == 1 + assert system_params["video_file_size_limit"] == 2 + assert system_params["audio_file_size_limit"] == 3 + assert system_params["file_size_limit"] == 4 + assert system_params["workflow_file_upload_limit"] == 5 + + @pytest.mark.parametrize( + ("features_dict", "user_input_form"), + [ + (None, []), + ([], []), + ("invalid", []), + ], + ) + def test_invalid_features_dict_type_raises(self, features_dict, user_input_form): + # Act & Assert + with pytest.raises(AttributeError): + parameters_mapping.get_parameters_from_feature_dict( + features_dict=features_dict, + user_input_form=user_input_form, + ) + + @pytest.mark.parametrize( + "user_input_form", + [None, "invalid", 123], + ) + def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input_form, + ) + + # Assert + assert result["user_input_form"] == user_input_form + + def test_empty_user_input_form(self, mock_config, mock_default_file_limits): + features = {} + user_input = [] + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + assert result["user_input_form"] == [] + + def test_feature_values_none(self, mock_config, mock_default_file_limits): + features = { + "suggested_questions": None, + "speech_to_text": None, + } + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + assert result["suggested_questions"] is None + assert result["speech_to_text"] is None diff --git a/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py new file mode 100644 index 0000000000..013ed0cbc4 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.common.sensitive_word_avoidance.manager import ( + SensitiveWordAvoidanceConfigManager, +) + + +class TestSensitiveWordAvoidanceConfigManagerConvert: + """Tests for convert classmethod""" + + @pytest.mark.parametrize( + "config", + [ + {}, + {"sensitive_word_avoidance": None}, + {"sensitive_word_avoidance": {}}, + {"sensitive_word_avoidance": {"enabled": False}}, + ], + ) + def test_convert_returns_none_when_disabled_or_missing(self, config): + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result is None + + def test_convert_returns_entity_when_enabled(self, mocker): + # Arrange + mock_entity = MagicMock() + mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"key": "value"}, + } + } + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result == mock_entity + + def test_convert_enabled_without_type_or_config(self, mocker): + # Arrange + mock_entity = MagicMock() + patched = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = {"sensitive_word_avoidance": {"enabled": True}} + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + patched.assert_called_once_with(type=None, config={}) + assert result == mock_entity + + +class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults: + """Tests for validate_and_set_defaults classmethod""" + + @pytest.fixture + def base_config(self): + return {} + + def test_validate_sets_default_when_missing(self, base_config): + # Act + config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=base_config.copy() + ) + + # Assert + assert config["sensitive_word_avoidance"]["enabled"] is False + assert fields == ["sensitive_word_avoidance"] + + def test_validate_raises_when_not_dict(self): + config = {"sensitive_word_avoidance": "invalid"} + + with pytest.raises(ValueError, match="must be of dict type"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + @pytest.mark.parametrize( + "config", + [ + {"sensitive_word_avoidance": {"enabled": False}}, + {"sensitive_word_avoidance": {"enabled": None}}, + {"sensitive_word_avoidance": {}}, + ], + ) + def test_validate_disables_when_enabled_false_or_missing(self, config): + # Act + result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + assert result_config["sensitive_word_avoidance"]["enabled"] is False + + def test_validate_raises_when_enabled_true_without_type(self): + config = {"sensitive_word_avoidance": {"enabled": True}} + + with pytest.raises(ValueError, match="type is required"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_type_not_string(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": 123, + } + } + + with pytest.raises(ValueError, match="must be a string"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_config_not_dict(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": "invalid", + } + } + + with pytest.raises(ValueError, match="must be a dict"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_calls_moderation_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"}) + assert result_config["sensitive_word_avoidance"]["enabled"] is True + assert fields == ["sensitive_word_avoidance"] + + def test_validate_sets_empty_dict_when_config_none(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": None, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={}) + + def test_validate_only_structure_validate_skips_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config, only_structure_validate=True + ) + + # Assert + mock_validate.assert_not_called() diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py new file mode 100644 index 0000000000..992b580376 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager + + +class TestAgentConfigManagerConvert: + @pytest.fixture + def base_config(self): + return { + "agent_mode": { + "enabled": True, + "strategy": "cot", + "tools": [], + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "completion", + }, + } + + def test_convert_returns_none_when_agent_mode_missing(self): + config = {"model": {"provider": "openai", "name": "gpt-4"}} + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}]) + def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config): + config = base_config.copy() + config["agent_mode"] = agent_mode_value + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize( + ("strategy_input", "expected_enum"), + [ + ("function_call", "FUNCTION_CALLING"), + ("cot", "CHAIN_OF_THOUGHT"), + ("react", "CHAIN_OF_THOUGHT"), + ], + ) + def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": strategy_input, + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.strategy.name == expected_enum + + def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "openai" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "FUNCTION_CALLING" + + def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "anthropic" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "CHAIN_OF_THOUGHT" + + def test_convert_skips_disabled_tools(self, mocker, base_config): + # Patch AgentEntity to bypass pydantic validation + mock_agent_entity = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity", + return_value=MagicMock(), + ) + + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value={ + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "tool_parameters": {}, + "credential_id": None, + }, + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + { + "provider_type": "type1", + "provider_id": "id1", + "tool_name": "tool1", + "enabled": False, + }, + { + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "enabled": True, + "extra_key": "x", + }, + ], + } + + AgentConfigManager.convert(config) + + mock_validate.assert_called_once() + mock_agent_entity.assert_called_once() + + def test_convert_tool_requires_minimum_keys(self, mocker, base_config): + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value=MagicMock(), + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + {"a": 1, "b": 2}, # insufficient keys + ], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.tools == [] + mock_validate.assert_not_called() + + def test_convert_completion_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "completion" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_chat_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "chat" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_react_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "react_router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_max_iteration_default(self, base_config): + config = base_config.copy() + config["agent_mode"].pop("max_iteration", None) + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 10 + + def test_convert_custom_max_iteration(self, base_config): + config = base_config.copy() + config["agent_mode"]["max_iteration"] = 25 + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 25 + + def test_convert_missing_model_raises_key_error(self, base_config): + config = base_config.copy() + del config["model"] + + with pytest.raises(KeyError): + AgentConfigManager.convert(config) + + @pytest.mark.parametrize( + ("invalid_config", "should_raise"), + [ + (None, True), + (123, True), + ("", False), + ([], False), + ], + ) + def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise): + if should_raise: + with pytest.raises(TypeError): + AgentConfigManager.convert(invalid_config) # type: ignore + else: + result = AgentConfigManager.convert(invalid_config) # type: ignore + assert result is None diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py new file mode 100644 index 0000000000..a688e2a5c5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -0,0 +1,319 @@ +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def valid_uuid(): + return str(uuid.uuid4()) + + +@pytest.fixture +def base_config(valid_uuid): + return { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + + +@pytest.fixture +def mock_dataset_service(mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + +# ============================== +# convert tests +# ============================== + + +class TestDatasetConfigManagerConvert: + def test_convert_returns_none_when_no_datasets(self): + config = {"dataset_configs": {"datasets": {"datasets": []}}} + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_single_retrieval(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result is not None + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable == "query" + + def test_convert_single_with_metadata_configs(self, valid_uuid, mocker): + mock_retrieve_config = MagicMock() + mock_entity = MagicMock() + mock_entity.dataset_ids = [valid_uuid] + mock_entity.retrieve_config = mock_retrieve_config + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig", + return_value={"mock": "model"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition", + return_value={"mock": "condition"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity", + return_value=mock_retrieve_config, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity", + return_value=mock_entity, + ) + + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "metadata_filtering_mode": "manual", + "metadata_model_config": {"any": "value"}, + "metadata_filtering_conditions": {"any": "value"}, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config is mock_retrieve_config + + def test_convert_multiple_defaults(self, valid_uuid): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 4 + assert result.retrieve_config.score_threshold is None + assert result.retrieve_config.reranking_enabled is True + + def test_convert_agent_mode_disabled_tool(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}], + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_dataset_configs_none(self): + config = {"dataset_configs": None} + with pytest.raises(TypeError): + DatasetConfigManager.convert(config) + + def test_convert_agent_mode_old_style_old_format(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable is None + + def test_convert_multiple_with_score_threshold(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "multiple", + "top_k": 5, + "score_threshold": 0.8, + "score_threshold_enabled": True, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 5 + assert result.retrieve_config.score_threshold == 0.8 + + @pytest.mark.parametrize( + "dataset_entry", + [ + {}, + {"invalid": {}}, + {"dataset": {"id": None, "enabled": True}}, + {"dataset": {"id": "", "enabled": False}}, + ], + ) + def test_convert_ignores_invalid_dataset_entries(self, dataset_entry): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": {"strategy": "router", "datasets": [dataset_entry]}, + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_agent_mode_old_style(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + + +# ============================== +# validate_and_set_defaults tests +# ============================== + + +class TestValidateAndSetDefaults: + def test_validate_sets_defaults(self): + config = {} + updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + assert "dataset_configs" in updated + assert updated["dataset_configs"]["retrieval_model"] == "single" + assert isinstance(fields, list) + + def test_validate_raises_when_dataset_configs_not_dict(self): + config = {"dataset_configs": "invalid"} + with pytest.raises(AttributeError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + + def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid): + config = { + "dataset_configs": { + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + } + with pytest.raises(ValueError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfig: + def test_extract_sets_defaults(self): + config = {} + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + assert "agent_mode" in result + assert result["agent_mode"]["enabled"] is False + assert result["agent_mode"]["tools"] == [] + + def test_extract_invalid_agent_mode_type(self): + config = {"agent_mode": "invalid"} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_enabled_type(self): + config = {"agent_mode": {"enabled": "yes"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_tools_type(self): + config = {"agent_mode": {"enabled": True, "tools": "invalid"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_uuid(self, mocker): + invalid_uuid = "not-a-uuid" + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_dataset_not_exists(self, valid_uuid, mocker): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + +# ============================== +# is_dataset_exists tests +# ============================== + + +class TestIsDatasetExists: + def test_dataset_exists_true(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "other" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py new file mode 100644 index 0000000000..186b4a501d --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.entities.model_entities import ModelStatus +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + + +class TestModelConfigConverter: + @pytest.fixture(autouse=True) + def patch_response_entity(self, mocker): + """ + Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation + and return a simple namespace object instead. + """ + + def _factory(**kwargs): + return SimpleNamespace(**kwargs) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity", + side_effect=_factory, + ) + + @pytest.fixture + def mock_app_config(self): + app_config = MagicMock() + app_config.tenant_id = "tenant_1" + + model_config = MagicMock() + model_config.provider = "openai" + model_config.model = "gpt-4" + model_config.parameters = {"temperature": 0.5} + model_config.mode = None + + app_config.model = model_config + return app_config + + @pytest.fixture + def mock_provider_bundle(self): + bundle = MagicMock() + + # configuration + configuration = MagicMock() + configuration.provider.provider = "openai" + configuration.get_current_credentials.return_value = {"api_key": "key"} + + provider_model = MagicMock() + provider_model.status = ModelStatus.ACTIVE + configuration.get_provider_model.return_value = provider_model + + bundle.configuration = configuration + + # model type instance + model_type_instance = MagicMock() + model_schema = MagicMock() + model_schema.model_properties = {} + model_type_instance.get_model_schema.return_value = model_schema + bundle.model_type_instance = model_type_instance + + return bundle + + @pytest.fixture + def patch_provider_manager(self, mocker, mock_provider_bundle): + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + return mock_manager + + # ============================= + # Positive Scenarios + # ============================= + + def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager): + result = ModelConfigConverter.convert(mock_app_config) + + assert result.provider == "openai" + assert result.model == "gpt-4" + assert result.mode == LLMMode.CHAT + assert result.parameters == {"temperature": 0.5} + assert result.stop == [] + + def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager): + mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]} + + result = ModelConfigConverter.convert(mock_app_config) + + assert result.parameters == {"temperature": 0.7} + assert result.stop == ["\n"] + + def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker): + mock_app_config.model.mode = None + + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: LLMMode.COMPLETION.value + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.COMPLETION + + def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: "invalid" + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.CHAT + + # ============================= + # Credential Errors + # ============================= + + def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_current_credentials.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ProviderTokenNotInitError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Provider Model Errors + # ============================= + + def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_provider_model.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError), + (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError), + (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError), + ], + ) + def test_convert_provider_model_status_errors( + self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception + ): + mock_provider = MagicMock() + mock_provider.status = status + mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(expected_exception): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Schema Errors + # ============================= + + def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Edge Cases + # ============================= + + @pytest.mark.parametrize( + "parameters", + [ + {}, + {"stop": []}, + {"stop": ["END"], "max_tokens": 100}, + ], + ) + def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters): + mock_app_config.model.parameters = parameters.copy() + + result = ModelConfigConverter.convert(mock_app_config) + + if "stop" in parameters: + assert result.stop == parameters.get("stop") + expected_params = parameters.copy() + expected_params.pop("stop", None) + assert result.parameters == expected_params + else: + assert result.stop == [] + assert result.parameters == parameters diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py new file mode 100644 index 0000000000..68bca485bb --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -0,0 +1,216 @@ +from unittest.mock import MagicMock + +import pytest + +# Target +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def valid_completion_params(): + return {"temperature": 0.7, "stop": ["\n"]} + + +@pytest.fixture +def valid_model_list(): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {"mode": "chat"} + return [model] + + +@pytest.fixture +def provider_entities(): + provider = MagicMock() + provider.provider = "openai/gpt" + return [provider] + + +@pytest.fixture +def valid_config(): + return { + "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}} + } + + +# ----------------------------- +# Test Class +# ----------------------------- + + +class TestModelConfigManager: + @staticmethod + def _patch_model_assembly(mocker, *, provider_entities, model_list): + assembly = MagicMock() + assembly.model_provider_factory.get_providers.return_value = provider_entities + assembly.provider_manager.get_configurations.return_value.get_models.return_value = model_list + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) + return assembly + + # ========================================================== + # convert + # ========================================================== + + def test_convert_success(self, valid_config): + result = ModelConfigManager.convert(valid_config) + + assert result.provider == "openai/gpt" + assert result.model == "gpt-4" + assert result.parameters == {"temperature": 0.5} + assert result.stop == ["END"] + + def test_convert_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.convert({}) + + def test_convert_without_stop(self): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}} + result = ModelConfigManager.convert(config) + assert result.stop == [] + assert result.parameters == {"temperature": 0.9} + + # ========================================================== + # validate_model_completion_params + # ========================================================== + + @pytest.mark.parametrize( + "invalid_cp", + [None, "string", 123, []], + ) + def test_validate_model_completion_params_invalid_type(self, invalid_cp): + with pytest.raises(ValueError, match="must be of object type"): + ModelConfigManager.validate_model_completion_params(invalid_cp) + + def test_validate_model_completion_params_default_stop(self): + cp = {"temperature": 0.2} + result = ModelConfigManager.validate_model_completion_params(cp) + assert result["stop"] == [] + + def test_validate_model_completion_params_invalid_stop_type(self): + cp = {"stop": "invalid"} + with pytest.raises(ValueError, match="must be of list type"): + ModelConfigManager.validate_model_completion_params(cp) + + def test_validate_model_completion_params_stop_length_exceeded(self): + cp = {"stop": [1, 2, 3, 4, 5]} + with pytest.raises(ValueError, match="less than 4"): + ModelConfigManager.validate_model_completion_params(cp) + + # ========================================================== + # validate_and_set_defaults + # ========================================================== + + def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) + + assert updated_config["model"]["mode"] == "chat" + assert keys == ["model"] + + def test_validate_and_set_defaults_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", {}) + + def test_validate_and_set_defaults_model_not_dict(self): + with pytest.raises(ValueError, match="object type"): + ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"}) + + def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): + config = {"model": {"name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): + config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.name is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {} + + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[model]) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + assert updated_config["model"]["mode"] == "completion" + + def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="completion_params is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list): + """ + Covers branch where provider does not contain '/' and + ModelProviderID conversion is triggered (line 64). + """ + config = { + "model": { + "provider": "openai", # no slash -> triggers conversion + "name": "gpt-4", + "completion_params": {}, + } + } + + # Mock ModelProviderID to return formatted provider + mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") + mock_provider_id.return_value = "openai/gpt" + provider_entity = MagicMock() + provider_entity.provider = "openai/gpt" + self._patch_model_assembly(mocker, provider_entities=[provider_entity], model_list=valid_model_list) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + # Ensure conversion happened + mock_provider_id.assert_called_once_with("openai") + assert updated_config["model"]["provider"] == "openai/gpt" diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py new file mode 100644 index 0000000000..fd49072cd5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( + PromptTemplateConfigManager, +) + +# ----------------------------- +# Helpers +# ----------------------------- + + +class DummyEnumValue: + def __init__(self, value): + self.value = value + + +class DummyPromptType: + def __init__(self): + self.SIMPLE = "simple" + self.ADVANCED = "advanced" + + def value_of(self, value): + return value + + def __iter__(self): + return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + + +# ----------------------------- +# Convert Tests +# ----------------------------- + + +class TestPromptTemplateConfigManagerConvert: + def test_convert_missing_prompt_type_raises(self): + with pytest.raises(ValueError, match="prompt_type is required"): + PromptTemplateConfigManager.convert({}) + + def test_convert_simple_prompt(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mock_prompt_entity_cls.return_value = "simple_entity" + + config = {"prompt_type": "simple", "pre_prompt": "hello"} + + result = PromptTemplateConfigManager.convert(config) + + assert result == "simple_entity" + mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello") + + def test_convert_advanced_chat_valid(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of", + return_value="role_enum", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity", + return_value="chat_msg", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity", + return_value="chat_template", + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]}, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + @pytest.mark.parametrize( + "message", + [ + {"text": 123, "role": "user"}, + {"text": "hi", "role": 123}, + ], + ) + def test_convert_advanced_invalid_message_fields(self, mocker, message): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [message]}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.convert(config) + + def test_convert_advanced_completion_with_roles(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity", + return_value="completion_template", + ) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "complete"}, + "conversation_histories_role": { + "user_prefix": "U", + "assistant_prefix": "A", + }, + }, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + +# ----------------------------- +# validate_and_set_defaults +# ----------------------------- + + +class TestValidateAndSetDefaults: + def setup_method(self): + self.valid_model = {"mode": "chat"} + + def _patch_prompt_type(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + return mock_prompt_entity_cls + + def test_default_prompt_type_set(self, mocker): + self._patch_prompt_type(mocker) + + config = {"model": self.valid_model} + + result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + assert result["prompt_type"] == "simple" + assert isinstance(keys, list) + + def test_invalid_prompt_type_raises(self, mocker): + class InvalidEnum(DummyPromptType): + def __iter__(self): + return iter([DummyEnumValue("valid")]) + + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = InvalidEnum() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = {"prompt_type": "invalid", "model": self.valid_model} + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_invalid_chat_prompt_config_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "chat_prompt_config": "invalid", + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_simple_mode_invalid_pre_prompt_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "pre_prompt": 123, + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_requires_one_config(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {}, + "completion_prompt_config": {}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_invalid_model_mode(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": []}, + "model": {"mode": "invalid"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_chat_prompt_length_exceeds(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{}] * 11}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_completion_prefix_defaults_set_when_empty(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "hi"}, + "conversation_histories_role": { + "user_prefix": "", + "assistant_prefix": "", + }, + }, + "model": {"mode": "completion"}, + } + + updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config) + + roles = updated["completion_prompt_config"]["conversation_histories_role"] + assert roles["user_prefix"] == "Human" + assert roles["assistant_prefix"] == "Assistant" + + +# ----------------------------- +# validate_post_prompt +# ----------------------------- + + +class TestValidatePostPrompt: + @pytest.mark.parametrize("value", [None, ""]) + def test_post_prompt_defaults(self, value): + config = {"post_prompt": value} + result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) + assert result["post_prompt"] == "" + + def test_post_prompt_invalid_type(self): + config = {"post_prompt": 123} + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py new file mode 100644 index 0000000000..d9fe7004ff --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -0,0 +1,286 @@ +import pytest + +from core.app.app_config.easy_ui_based_app.variables.manager import ( + BasicVariablesConfigManager, +) +from graphon.variables.input_entities import VariableEntityType + + +class TestBasicVariablesConfigManagerConvert: + def test_convert_empty_config(self): + config = {} + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + def test_convert_external_data_tools_enabled_and_disabled(self, mocker): + config = { + "external_data_tools": [ + {"enabled": False}, + { + "enabled": True, + "variable": "ext_var", + "type": "tool_type", + "config": {"k": "v"}, + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert len(external) == 1 + assert external[0].variable == "ext_var" + assert external[0].type == "tool_type" + + def test_convert_user_input_form_variable_types(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "variable": "name", + "label": "Name", + "description": "desc", + "required": True, + "max_length": 50, + } + }, + { + VariableEntityType.SELECT: { + "variable": "choice", + "label": "Choice", + "options": ["a", "b"], + } + }, + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + "config": {"x": 1}, + } + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert len(variables) == 2 + assert len(external) == 1 + + def test_convert_external_data_tool_without_config_skipped(self): + config = { + "user_input_form": [ + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + } + } + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + +class TestValidateVariablesAndSetDefaults: + def test_validate_sets_empty_user_input_form_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"] == [] + assert "user_input_form" in keys + + def test_validate_user_input_form_not_list_raises(self): + config = {"user_input_form": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_invalid_key_raises(self): + config = {"user_input_form": [{"invalid": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_label_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_label_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_variable_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_variable_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + @pytest.mark.parametrize( + "variable_name", + ["1invalid", "invalid space", "", None], + ) + def test_validate_variable_invalid_pattern_raises(self, variable_name): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": variable_name, + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_required_default_and_type(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + } + } + ] + } + + updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False + + def test_validate_required_not_bool_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + "required": "yes", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_default_not_in_options_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": ["a", "b"], + "default": "c", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_not_list_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": "not_list", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + +class TestValidateExternalDataToolsAndSetDefaults: + def test_validate_sets_empty_external_data_tools_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + assert updated["external_data_tools"] == [] + assert "external_data_tools" in keys + + def test_validate_external_data_tools_not_list_raises(self): + config = {"external_data_tools": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_disabled_tool_skipped(self, mocker): + config = {"external_data_tools": [{"enabled": False}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + spy.assert_not_called() + assert updated["external_data_tools"][0]["enabled"] is False + + def test_validate_enabled_tool_missing_type_raises(self): + config = {"external_data_tools": [{"enabled": True, "config": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_enabled_tool_calls_factory(self, mocker): + config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config) + + spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1}) + + +class TestValidateAndSetDefaultsIntegration: + def test_validate_and_set_defaults_calls_both(self, mocker): + config = {} + + spy_var = mocker.patch.object( + BasicVariablesConfigManager, + "validate_variables_and_set_defaults", + return_value=(config, ["user_input_form"]), + ) + spy_ext = mocker.patch.object( + BasicVariablesConfigManager, + "validate_external_data_tools_and_set_defaults", + return_value=(config, ["external_data_tools"]), + ) + + updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config) + + spy_var.assert_called_once() + spy_ext.assert_called_once() + assert "user_input_form" in keys + assert "external_data_tools" in keys + assert updated == config diff --git a/api/tests/unit_tests/core/app/app_config/features/__init__.py b/api/tests/unit_tests/core/app/app_config/features/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index de99833aac..11fc15c94d 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,6 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py new file mode 100644 index 0000000000..dd00c3defc --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -0,0 +1,115 @@ +import pytest + +from core.app.app_config.entities import TextToSpeechEntity +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager + + +class TestAdditionalFeatureManagers: + def test_opening_statement_validate_defaults(self): + config, keys = OpeningStatementConfigManager.validate_and_set_defaults({}) + assert config["opening_statement"] == "" + assert config["suggested_questions"] == [] + assert set(keys) == {"opening_statement", "suggested_questions"} + + def test_opening_statement_validate_types(self): + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123}) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": "bad"} + ) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": [1]} + ) + + def test_opening_statement_convert(self): + opening, questions = OpeningStatementConfigManager.convert( + {"opening_statement": "hello", "suggested_questions": ["q1"]} + ) + assert opening == "hello" + assert questions == ["q1"] + + def test_retrieval_resource_validate(self): + config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({}) + assert config["retriever_resource"]["enabled"] is False + assert keys == ["retriever_resource"] + + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"}) + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}}) + + def test_retrieval_resource_convert(self): + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False + + def test_speech_to_text_validate_and_convert(self): + config, keys = SpeechToTextConfigManager.validate_and_set_defaults({}) + assert config["speech_to_text"]["enabled"] is False + assert keys == ["speech_to_text"] + + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"}) + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}}) + + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False + + def test_suggested_questions_after_answer_validate_and_convert(self): + config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({}) + assert config["suggested_questions_after_answer"]["enabled"] is False + assert keys == ["suggested_questions_after_answer"] + + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": "bad"} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": "yes"}} + ) + + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) + is True + ) + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}}) + is False + ) + + def test_text_to_speech_validate_and_convert(self): + config, keys = TextToSpeechConfigManager.validate_and_set_defaults({}) + assert config["text_to_speech"]["enabled"] is False + assert keys == ["text_to_speech"] + + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"}) + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}}) + + result = TextToSpeechConfigManager.convert( + {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}} + ) + assert isinstance(result, TextToSpeechEntity) + assert result.voice == "v" + assert result.language == "en" + + def test_more_like_this_convert_and_validate(self): + config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({}) + assert config["more_like_this"]["enabled"] is False + assert keys == ["more_like_this"] + + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False + with pytest.raises(ValueError): + MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"}) diff --git a/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py new file mode 100644 index 0000000000..e99852cf76 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py @@ -0,0 +1,180 @@ +from collections import UserDict +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager + + +class TestBaseAppConfigManager: + @pytest.fixture + def mock_config_dict(self): + return {"key": "value", "another": 123} + + @pytest.fixture + def mock_app_additional_features(self, mocker): + mock_instance = MagicMock() + mocker.patch( + "core.app.app_config.base_app_config_manager.AppAdditionalFeatures", + return_value=mock_instance, + ) + return mock_instance + + @pytest.fixture + def mock_managers(self, mocker): + retrieval = mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + return_value="retrieval_result", + ) + file_upload = mocker.patch( + "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert", + return_value="file_upload_result", + ) + opening_statement = mocker.patch( + "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert", + return_value=("opening_result", "suggested_result"), + ) + suggested_after = mocker.patch( + "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert", + return_value="suggested_after_result", + ) + more_like_this = mocker.patch( + "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert", + return_value="more_like_this_result", + ) + speech_to_text = mocker.patch( + "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert", + return_value="speech_to_text_result", + ) + text_to_speech = mocker.patch( + "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert", + return_value="text_to_speech_result", + ) + + return { + "retrieval": retrieval, + "file_upload": file_upload, + "opening_statement": opening_statement, + "suggested_after": suggested_after, + "more_like_this": more_like_this, + "speech_to_text": speech_to_text, + "text_to_speech": text_to_speech, + } + + @pytest.mark.parametrize( + ("app_mode", "expected_is_vision"), + [ + ("CHAT", True), + ("COMPLETION", True), + ("AGENT_CHAT", True), + ("OTHER", False), + ], + ) + def test_convert_features_all_modes( + self, + mocker, + mock_config_dict, + mock_app_additional_features, + mock_managers, + app_mode, + expected_is_vision, + ): + # Arrange + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode) + + # Assert + assert result == mock_app_additional_features + mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["file_upload"].assert_called_once() + _, kwargs = mock_managers["file_upload"].call_args + assert kwargs["config"] == dict(mock_config_dict.items()) + assert kwargs["is_vision"] is expected_is_vision + + mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items())) + + def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + empty_config = {} + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(empty_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called + + @pytest.mark.parametrize( + "invalid_config", + [ + None, + "string", + 123, + 12.34, + [], + ], + ) + def test_convert_features_invalid_config_raises(self, invalid_config): + # Act & Assert + with pytest.raises((TypeError, AttributeError)): + BaseAppConfigManager.convert_features(invalid_config, "CHAT") + + def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict): + # Arrange + mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + side_effect=RuntimeError("manager failure"), + ) + + # Act & Assert + with pytest.raises(RuntimeError): + BaseAppConfigManager.convert_features(mock_config_dict, "CHAT") + + def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + class CustomMapping(UserDict): + pass + + custom_config = CustomMapping({"a": 1}) + + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(custom_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py new file mode 100644 index 0000000000..f2bc3076da --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -0,0 +1,43 @@ +import pytest + +from core.app.app_config.entities import ( + DatasetRetrieveConfigEntity, + PromptTemplateEntity, +) +from graphon.variables.input_entities import VariableEntity, VariableEntityType + + +class TestAppConfigEntities: + def test_variable_entity_coerces_none_description_and_options(self): + entity = VariableEntity( + variable="query", + label="Query", + description=None, + type=VariableEntityType.TEXT_INPUT, + options=None, + ) + + assert entity.description == "" + assert entity.options == [] + + def test_variable_entity_rejects_invalid_json_schema(self): + with pytest.raises(ValueError): + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + json_schema={"type": "string", "minLength": "bad"}, + ) + + def test_prompt_template_value_of(self): + assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE + with pytest.raises(ValueError): + PromptTemplateEntity.PromptType.value_of("missing") + + def test_dataset_retrieve_strategy_value_of(self): + assert ( + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single") + == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + ) + with pytest.raises(ValueError): + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing") diff --git a/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py new file mode 100644 index 0000000000..fa128aca87 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py @@ -0,0 +1,222 @@ +import pytest + +from core.app.app_config.workflow_ui_based_app.variables.manager import ( + WorkflowVariablesConfigManager, +) + +# ============================= +# Fixtures +# ============================= + + +@pytest.fixture +def mock_workflow(mocker): + workflow = mocker.MagicMock() + workflow.graph_dict = {"nodes": []} + return workflow + + +@pytest.fixture +def mock_variable_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity") + + +@pytest.fixture +def mock_rag_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity") + + +# ============================= +# Test Convert (user_input_form) +# ============================= + + +class TestWorkflowVariablesConfigManagerConvert: + def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity): + # Arrange + input_variables = [{"name": "var1"}, {"name": "var2"}] + mock_workflow.user_input_form.return_value = input_variables + mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [{"validated": v} for v in input_variables] + assert mock_variable_entity.model_validate.call_count == 2 + + def test_convert_empty_list(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [] + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [] + mock_variable_entity.model_validate.assert_not_called() + + def test_convert_none_returned_raises(self, mock_workflow): + # Arrange + mock_workflow.user_input_form.return_value = None + + # Act & Assert + with pytest.raises(TypeError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [{"invalid": "data"}] + mock_variable_entity.model_validate.side_effect = ValueError("validation error") + + # Act & Assert + with pytest.raises(ValueError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + +# ============================= +# Test convert_rag_pipeline_variable +# ============================= + + +class TestWorkflowVariablesConfigManagerConvertRag: + def test_no_rag_pipeline_variables(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = [] + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_rag_pipeline_none(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = None + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}] + + def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + @pytest.mark.parametrize( + ("belong_to_node_id", "expected_count"), + [ + ("node1", 1), + ("shared", 1), + ("other_node", 0), + ], + ) + def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": belong_to_node_id}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == expected_count + + def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + + def test_validation_error_propagates(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed") + + # Act & Assert + with pytest.raises(RuntimeError): + WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 441d2fcd17..af5d203f12 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError from constants import UUID_NIL from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.advanced_chat.generate_task_pipeline import ( + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager +from libs.datetime_utils import naive_utc_now +from models.enums import MessageStatus from models.model import AppMode @@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={}) conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-1") + message = SimpleNamespace( + id="msg-1", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) captured: dict[str, object] = {} thread_data: dict[str, object] = {} @@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals: pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") response = generator._generate( - workflow=SimpleNamespace(features={"feature": True}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals: db_session.refresh.assert_called_once_with(conversation) db_session.close.assert_called_once() assert captured["draft_var_saver_factory"] == "draft-factory" + assert isinstance(captured["workflow"], WorkflowSnapshot) + assert isinstance(captured["conversation"], ConversationSnapshot) + assert isinstance(captured["message"], MessageSnapshot) def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): generator = AdvancedChatAppGenerator() @@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={}) conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-2") + message = SimpleNamespace( + id="msg-2", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) init_records = MagicMock() thread_data: dict[str, object] = {} @@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals: ) response = generator._generate( - workflow=SimpleNamespace(features={}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(GenerateTaskStoppedError): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(ValueError, match="other error"): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals: logger_exception.assert_called_once() - def test_refresh_model_returns_detached_model(self, monkeypatch): - source_model = SimpleNamespace(id="source-id") - detached_model = SimpleNamespace(id="source-id", detached=True) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def get(self, model_type, model_id): - _ = model_type - return detached_model if model_id == "source-id" else None - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) - - refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) - - assert refreshed is detached_model - def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): generator = AdvancedChatAppGenerator() generator._dialogue_count = 1 @@ -1053,7 +1038,7 @@ class TestAdvancedChatAppGeneratorInternals: _ = kwargs def run(self): - from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + from graphon.model_runtime.errors.invoke import InvokeAuthorizationError raise InvokeAuthorizationError("bad key") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c7..ef7df5e1da 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -7,10 +7,23 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from dify_graph.variables import SegmentType from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e2..079df0b4e6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index 5b199e0c52..f2df35d7d0 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -10,7 +10,7 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231..99a386cd45 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -17,11 +17,11 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent -from models.model import EndUser +from models.model import AppMode, EndUser def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) @@ -160,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None: task_id="task-1", ) queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( + conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT) + message = pipeline_module.MessageSnapshot( id="message-1", created_at=datetime(2024, 1, 1), query="hello", @@ -171,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None: user = EndUser() user.id = "user-1" user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={}) pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -185,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None: draft_var_saver_factory=SimpleNamespace(), ) - pipeline._get_message = mock.Mock(return_value=message) + stored_message = SimpleNamespace( + id="message-1", + answer="before", + status=MessageStatus.PAUSED, + updated_at=None, + provider_response_latency=0, + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + total_price=0, + currency="USD", + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="user-1", + ) + pipeline._get_message = mock.Mock(return_value=stored_message) pipeline._recorded_files = [] list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) pipeline._save_message(session=mock.Mock()) - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL + assert stored_message.answer == "beforeafter" + assert stored_message.status == MessageStatus.NORMAL def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea..29fd63c063 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -1,13 +1,17 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -42,11 +46,13 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool +from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -72,15 +78,15 @@ def _make_pipeline(): workflow_run_id="run-id", ) - message = SimpleNamespace( + message = MessageSnapshot( id="message-id", query="hello", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), status=MessageStatus.NORMAL, answer="", ) - conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) - workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={}) user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") pipeline = AdvancedChatAppGenerateTaskPipeline( @@ -166,7 +172,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -256,7 +262,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -272,7 +278,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -280,7 +286,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -296,7 +302,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) @@ -311,7 +317,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -359,7 +365,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -369,7 +375,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -472,7 +478,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] @@ -522,7 +528,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +562,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" @@ -590,7 +596,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 53f26d1592..80f7f94b1a 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -6,7 +6,7 @@ from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 5603115b30..4567b35480 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -3,8 +3,8 @@ import pytest from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 3cdffbb4cd..8f3c41701b 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 67b3777c40..f56ca8de99 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -9,8 +9,8 @@ from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.file.enums import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1e..d6f7a05cdc 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,13 +3,15 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from dify_graph.runtime import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a3347..3ab63aed25 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from dify_graph.variables.segments import ArrayFileSegment, FileSegment +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: @@ -12,7 +12,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +222,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd0..e8946281ac 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,13 +4,13 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b..492e11ee0f 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc4..7ee375d884 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -24,9 +24,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 51f33bac35..aa2085177e 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -6,7 +6,7 @@ import pytest import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 2714757353..f2e35f9900 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -9,7 +9,7 @@ import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 94ed8166b9..cfe797aa76 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -10,7 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 72f7552bd1..9db83f5531 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -13,7 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f39..fb19d6d761 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -26,7 +26,7 @@ import pytest import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.graph_events import GraphRunFailedEvent +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced02394..b0f8b423e1 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,9 +1,7 @@ -from unittest.mock import MagicMock - import pytest from core.app.apps.base_app_generator import BaseAppGenerator -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -403,11 +401,11 @@ class TestBaseAppGeneratorExtras: monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mapping", - lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object", ) monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mappings", - lambda mappings, tenant_id, config: ["file-1", "file-2"], + lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"], ) user_inputs = { @@ -479,7 +477,7 @@ class TestBaseAppGeneratorExtras: def test_get_draft_var_saver_factory_debugger(self): from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() @@ -489,7 +487,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py index c6dc20ffc6..842d14bbd2 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -59,3 +59,18 @@ class TestBaseAppQueueManager: bad = SimpleNamespace(_sa_instance_state=True) with pytest.raises(TypeError): manager._check_for_sqlalchemy_models(bad) + + def test_stop_listen_defers_graph_runtime_state_cleanup_until_listener_exits(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = None + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + runtime_state = SimpleNamespace(name="runtime-state") + manager.graph_runtime_state = runtime_state + + manager.stop_listen() + + assert manager.graph_runtime_state is runtime_state + assert list(manager.listen()) == [] + assert manager.graph_runtime_state is None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index aabeb54553..17de39ca99 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -14,15 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8..3673b7f68e 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -3,33 +3,33 @@ import time from types import ModuleType, SimpleNamespace from typing import Any -import dify_graph.nodes.human_input.entities # noqa: F401 +import graphon.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import SchedulingPause +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, PauseRequestedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.base.node import Node +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +162,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569..58c7bfa4bc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from types import SimpleNamespace import pytest @@ -11,25 +11,35 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, NodeRunIterationSucceededEvent, NodeRunLoopFailedEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, + NodeRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.variables import StringVariable class TestWorkflowBasedAppRunner: @@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -78,12 +88,12 @@ class TestWorkflowBasedAppRunner: workflow = SimpleNamespace(environment_variables=[], graph_dict={}) with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): - runner._prepare_single_node_execution(workflow, None, None) + runner._prepare_single_node_execution(workflow, None, None, user_id="00000000-0000-0000-0000-000000000001") def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -126,11 +136,102 @@ class TestWorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id="00000000-0000-0000-0000-000000000001", ) assert graph is not None assert variable_pool is graph_runtime_state.variable_pool + def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init(self, monkeypatch): + variable_loader = SimpleNamespace( + load_variables=lambda selectors: ( + [ + StringVariable( + name="conversation_id", + value="conv-1", + selector=["sys", "conversation_id"], + ) + ] + if selectors + else [] + ) + ) + runner = WorkflowBasedAppRunner( + queue_manager=SimpleNamespace(), + variable_loader=variable_loader, + app_id="app", + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + + workflow = SimpleNamespace( + tenant_id="tenant", + id="workflow", + graph_dict={ + "nodes": [ + {"id": "loop-node", "data": {"type": "loop", "version": "1", "title": "Loop"}}, + { + "id": "llm-child", + "data": { + "type": "llm", + "version": "1", + "loop_id": "loop-node", + "memory": object(), + }, + }, + ], + "edges": [], + }, + ) + + class _LoopNodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + def _validate_node_config(value): + return {"id": value["id"], "data": SimpleNamespace(**value["data"])} + + def _graph_init(**kwargs): + variable_pool = graph_runtime_state.variable_pool + assert variable_pool.get(["sys", "conversation_id"]) is not None + return SimpleNamespace() + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.NodeConfigDictAdapter.validate_python", + _validate_node_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + _graph_init, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.resolve_workflow_node_class", + lambda **_kwargs: _LoopNodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert graph is not None + assert variable_pool.get(["sys", "conversation_id"]).value == "conv-1" + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): published: list[object] = [] @@ -140,7 +241,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -183,7 +284,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -195,7 +296,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.START, node_title="Start", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), ), ) runner._handle_event( @@ -232,7 +333,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Iter", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={"ok": True}, metadata={}, @@ -246,7 +347,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Loop", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={}, metadata={}, @@ -259,3 +360,87 @@ class TestWorkflowBasedAppRunner: assert any(isinstance(event, QueueAgentLogEvent) for event in published) assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) + + @pytest.mark.parametrize( + ("event_factory", "queue_event_cls"), + [ + ( + lambda result, start_at, finished_at: NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + node_run_result=result, + ), + QueueNodeSucceededEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeFailedEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeExceptionEvent, + ), + ( + lambda result, start_at, _finished_at: NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=start_at, + error="boom", + retry_index=1, + node_run_result=result, + ), + QueueNodeRetryEvent, + ), + ], + ) + def test_handle_start_node_result_events_project_outputs(self, event_factory, queue_event_cls): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + started_at = datetime.now(UTC) + finished_at = datetime.now(UTC) + result = NodeRunResult( + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + "conversation.session_id": "session-1", + }, + ) + + runner._handle_event(workflow_entry, event_factory(result, started_at, finished_at)) + + queue_event = published[-1] + assert isinstance(queue_event, queue_event_cls) + assert queue_event.outputs == {"question": "hello"} diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 1388279221..38a947986f 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -4,8 +4,8 @@ import pytest from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118e..620a153204 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -9,15 +9,15 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow def _make_graph_state(): variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch( workflow=workflow, single_iteration_run=single_iteration_run, single_loop_run=single_loop_run, + user_id="user", ) init_graph.assert_not_called() @@ -158,6 +159,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id="00000000-0000-0000-0000-000000000001", ) assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 65c6bd6654..ef0edf4096 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -10,13 +10,14 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.pause_reason import HumanInputRequired +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph_events.graph import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account +from models.human_input import RecipientType class _RecordingWorkflowAppRunner(WorkflowAppRunner): @@ -74,7 +75,6 @@ def test_graph_run_paused_event_emits_queue_pause_event(): actions=[], node_id="node-human", node_title="Human Step", - form_token="tok", ) event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) workflow_entry = SimpleNamespace( @@ -98,7 +98,7 @@ def _build_converter(): invoke_from=InvokeFrom.SERVICE_API, app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="user", app_id="app-id", workflow_id="workflow-id", @@ -128,7 +128,21 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon class _FakeSession: def execute(self, _stmt): - return [("form-1", expiration_time)] + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def scalars(self, _stmt): + return [ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.BACKSTAGE, + access_token="backstage-token", + ), + ] def __enter__(self): return self @@ -146,10 +160,8 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), ], actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, node_id="node-id", node_title="Human Step", - form_token="token", ) queue_event = QueueWorkflowPausedEvent( reasons=[reason], @@ -170,7 +182,6 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert pause_resp.data.paused_nodes == ["node-id"] assert pause_resp.data.outputs == {} assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert pause_resp.data.reasons[0]["display_in_ui"] is True assert isinstance(responses[0], HumanInputRequiredResponse) hi_resp = responses[0] @@ -180,4 +191,5 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert hi_resp.data.inputs[0].output_variable_name == "field" assert hi_resp.data.actions[0].id == "approve" assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "backstage-token" assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 62e94a7580..7dd7ffd727 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -9,7 +9,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 5b23e71035..a0a999cbc5 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -7,11 +7,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode +from tests.workflow_test_utils import build_test_variable_pool def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: @@ -37,11 +38,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) + variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index f35710d207..115e35da8a 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -1,7 +1,6 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest @@ -44,11 +43,13 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -191,7 +192,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -244,7 +245,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -302,7 +303,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -318,7 +319,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -326,7 +327,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -342,7 +343,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) filled_event = QueueHumanInputFormFilledEvent( @@ -358,7 +359,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) agent_event = QueueAgentLogEvent( id="log", @@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline: ) assert pipeline._created_by_role == CreatorUserRole.END_USER - assert pipeline._workflow_system_variables.user_id == "session-id" + assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id" def test_process_returns_stream_and_blocking_variants(self): pipeline = _make_pipeline() @@ -647,7 +648,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", error="error", @@ -659,7 +660,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", ) @@ -684,7 +685,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec-id", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( @@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) assert len(added) == count_before - def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + def test_save_output_for_event_writes_draft_variables(self): pipeline = _make_pipeline() saver_calls: list[tuple[object, object]] = [] captured_factory_args: dict[str, object] = {} @@ -828,36 +829,14 @@ class TestWorkflowGenerateTaskPipeline: captured_factory_args.update(kwargs) return _Saver() - class _Begin: - def __enter__(self): - return None - - def __exit__(self, exc_type, exc, tb): - return False - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return _Begin() - pipeline._draft_var_saver_factory = _factory - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) event = QueueNodeSucceededEvent( node_execution_id="exec-id", node_id="node-id", node_type=BuiltinNodeTypes.START, in_loop_id="loop-id", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), process_data={"k": "v"}, outputs={"out": 1}, ) diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py new file mode 100644 index 0000000000..7c21b00966 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -0,0 +1,12 @@ +from core.app.entities.queue_entities import QueueStopEvent + + +class TestQueueEntities: + def test_get_stop_reason_for_known_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + assert event.get_stop_reason() == "Stopped by user." + + def test_get_stop_reason_for_unknown_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + event.stopped_by = "unknown" + assert event.get_stop_reason() == "Stopped by unknown reason." diff --git a/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py new file mode 100644 index 0000000000..1e0ef6d6d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py @@ -0,0 +1,17 @@ +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity + + +class TestRagPipelineInvokeEntity: + def test_defaults_and_fields(self): + entity = RagPipelineInvokeEntity( + pipeline_id="pipe-1", + application_generate_entity={"foo": "bar"}, + user_id="user-1", + tenant_id="tenant-1", + workflow_id="workflow-1", + streaming=True, + ) + + assert entity.workflow_execution_id is None + assert entity.workflow_thread_pool_id is None + assert entity.streaming is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py new file mode 100644 index 0000000000..7c79780641 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -0,0 +1,78 @@ +from core.app.entities.task_entities import ( + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + StreamEvent, +) +from graphon.enums import WorkflowNodeExecutionStatus + + +class TestTaskEntities: + def test_node_start_to_ignore_detail_dict(self): + data = NodeStartStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + created_at=1, + ) + response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_STARTED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["extras"] == {} + + def test_node_finish_to_ignore_detail_dict(self): + data = NodeFinishStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ) + response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_FINISHED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["outputs"] is None + assert payload["data"]["files"] == [] + + def test_node_retry_to_ignore_detail_dict(self): + data = NodeRetryStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.RETRY, + elapsed_time=0.1, + created_at=1, + finished_at=2, + retry_index=2, + ) + response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_RETRY.value + assert payload["data"]["retry_index"] == 2 + assert payload["data"]["outputs"] is None diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" diff --git a/api/tests/unit_tests/core/app/features/test_annotation_reply.py b/api/tests/unit_tests/core/app/features/test_annotation_reply.py new file mode 100644 index 0000000000..e721a77079 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_annotation_reply.py @@ -0,0 +1,163 @@ +import logging +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature + + +class TestAnnotationReplyFeature: + def test_query_returns_none_when_setting_missing(self): + feature = AnnotationReplyFeature() + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = None + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_none_when_binding_missing(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace(collection_binding_detail=None) + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = annotation_setting + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_annotation_and_records_history_for_api(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result == annotation + mock_annotation_service.add_annotation_history.assert_called_once() + _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "api" + assert score == 0.8 + + def test_query_returns_annotation_and_records_history_for_console(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=0.5, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=None, + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.EXPLORE, + ) + + assert result == annotation + _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "console" + + def test_query_logs_and_returns_none_on_exception(self, caplog): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1") + mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom") + + with caplog.at_level(logging.WARNING): + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + assert "Query annotation failed" in caplog.text diff --git a/api/tests/unit_tests/core/app/features/test_hosting_moderation.py b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py new file mode 100644 index 0000000000..01194c16f5 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py @@ -0,0 +1,30 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature + + +class TestHostingModerationFeature: + def test_check_aggregates_text_and_calls_moderation(self): + application_generate_entity = Mock() + application_generate_entity.model_conf = {"model": "mock"} + application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1") + + prompt_messages = [ + SimpleNamespace(content="hello"), + SimpleNamespace(content=123), + SimpleNamespace(content="world"), + ] + + with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check: + mock_check.return_value = True + + feature = HostingModerationFeature() + result = feature.check(application_generate_entity, prompt_messages) + + assert result is True + mock_check.assert_called_once_with( + tenant_id="tenant-1", + model_config=application_generate_entity.model_conf, + text="hello\nworld\n", + ) diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index bdc889d941..279e315946 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,18 +1,17 @@ from collections.abc import Sequence -from datetime import datetime from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringVariable -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment +from libs.datetime_utils import naive_utc_now class MockReadOnlyVariablePool: @@ -36,31 +35,38 @@ def _build_graph_runtime_state( conversation_id: str | None = None, ) -> ReadOnlyGraphRuntimeState: graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + if conversation_id is not None: + variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment( + value=conversation_id + ) graph_runtime_state.variable_pool = variable_pool - graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() return graph_runtime_state -def _build_node_run_succeeded_event( - *, - node_type: NodeType, - outputs: dict[str, object] | None = None, - process_data: dict[str, object] | None = None, -) -> NodeRunSucceededEvent: +def _build_node_run_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="node-exec-id", node_id="assigner", - node_type=node_type, - start_at=datetime.utcnow(), + node_type=BuiltinNodeTypes.LLM, + start_at=naive_utc_now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs or {}, - process_data=process_data or {}, + outputs={}, + process_data={}, ), ) -def test_persists_conversation_variables_from_assigner_output(): +def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id="node-exec-id", + node_id="assigner", + node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, + variable=variable, + ) + + +def test_persists_conversation_variables_from_variable_update_event(): conversation_id = "conv-123" variable = StringVariable( id="var-1", @@ -68,55 +74,26 @@ def test_persists_conversation_variables_from_assigner_output(): value="updated", selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(variable) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) - updater.flush.assert_called_once() -def test_skips_when_outputs_missing(): +def test_skips_non_variable_update_events(): conversation_id = "conv-456" - variable = StringVariable( - id="var-2", - name="name", - value="updated", - selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event() layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() - - -def test_skips_non_assigner_nodes(): - updater = Mock() - layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) - layer.on_event(event) - - updater.update.assert_not_called() - updater.flush.assert_not_called() def test_skips_non_conversation_variables(): @@ -127,18 +104,11 @@ def test_skips_non_conversation_variables(): value="updated", selector=["environment", "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] - ) - - variable_pool = MockReadOnlyVariablePool() - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(non_conversation_variable) layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035f0ee05c..92a7788f6e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,17 +13,18 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import ( +from core.workflow.system_variables import SystemVariableKey +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from dify_graph.variables.segments import Segment +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from graphon.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -51,17 +52,6 @@ class TestDataFactory: return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count) -class MockSystemVariableReadOnlyView: - """Minimal read-only system variable view for testing.""" - - def __init__(self, workflow_execution_id: str | None = None) -> None: - self._workflow_execution_id = workflow_execution_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - class MockReadOnlyVariablePool: """Mock implementation of ReadOnlyVariablePool for testing.""" @@ -76,13 +66,14 @@ class MockReadOnlyVariablePool: return None mock_segment = Mock(spec=Segment) mock_segment.value = value + mock_segment.text = value if isinstance(value, str) else None return mock_segment def get_all_by_node(self, node_id: str) -> dict[str, object]: return {key: value for (nid, key), value in self._variables.items() if nid == node_id} def get_by_prefix(self, prefix: str) -> dict[str, object]: - return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)} + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} class MockReadOnlyGraphRuntimeState: @@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState: self._ready_queue_size = ready_queue_size self._exceptions_count = exceptions_count self._outputs = outputs or {} - self._variable_pool = MockReadOnlyVariablePool(variables) - self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id) - - @property - def system_variable(self) -> MockSystemVariableReadOnlyView: - return self._system_variable + resolved_variables = dict(variables or {}) + if workflow_execution_id is not None: + resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id + self._variable_pool = MockReadOnlyVariablePool(resolved_variables) @property def variable_pool(self) -> ReadOnlyVariablePool: @@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState: "exceptions_count": self._exceptions_count, "outputs": self._outputs, "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()}, - "workflow_execution_id": self._system_variable.workflow_execution_id, + "workflow_execution_id": self._variable_pool._variables.get( + ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value) + ), } ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py new file mode 100644 index 0000000000..56705f1a7e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -0,0 +1,19 @@ +from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events.graph import GraphRunPausedEvent + + +class TestSuspendLayer: + def test_on_event_accepts_paused_event(self): + layer = SuspendLayer() + assert layer.is_paused() is False + layer.on_graph_start() + assert layer.is_paused() is False + layer.on_event(GraphRunPausedEvent()) + assert layer.is_paused() is True + + def test_on_event_ignores_other_events(self): + layer = SuspendLayer() + layer.on_graph_start() + initial_state = layer.is_paused() + layer.on_event(object()) + assert layer.is_paused() is initial_state diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py new file mode 100644 index 0000000000..1ac9a4d8c0 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -0,0 +1,98 @@ +from unittest.mock import Mock, patch + +from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import SchedulerCommand + + +class TestTimeSliceLayer: + def test_init_starts_scheduler_when_not_running(self): + scheduler = Mock() + scheduler.running = False + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock())) + + scheduler.start.assert_called_once() + + def test_on_graph_start_adds_job_for_time_slice(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid, + ): + mock_uuid.return_value.hex = "job-1" + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.on_graph_start() + + assert layer.schedule_id == "job-1" + scheduler.add_job.assert_called_once() + + def test_on_graph_end_removes_job(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.schedule_id = "job-1" + layer.on_graph_end(None) + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_removes_when_stopped(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.stopped = True + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_handles_resource_limit_without_command_channel(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.logger") as mock_logger, + ): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + mock_logger.exception.assert_called_once() + + def test_checker_job_sends_pause_command(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.command_channel = Mock() + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + layer.command_channel.send_command.assert_called_once() + sent_command = layer.command_channel.send_command.call_args[0][0] + assert isinstance(sent_command, GraphEngineCommand) + assert sent_command.command_type == CommandType.PAUSE diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py new file mode 100644 index 0000000000..ecc431936c --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -0,0 +1,108 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.layers.trigger_post_layer import TriggerPostLayer +from core.workflow.system_variables import build_system_variables +from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool +from models.enums import WorkflowTriggerStatus + + +class TestTriggerPostLayer: + def test_on_event_updates_trigger_log(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"answer": "ok"}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=12, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunSucceededEvent()) + + assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 12 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + + def test_on_event_handles_missing_trigger_log(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.logger") as mock_logger, + ): + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = None + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="missing", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunFailedEvent(error="boom")) + + mock_logger.exception.assert_called_once() + session.commit.assert_not_called() + + def test_on_event_ignores_non_status_events(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory: + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(Mock()) + + mock_session_factory.create_session.assert_not_called() diff --git a/api/tests/unit_tests/core/app/task_pipeline/__init__.py b/api/tests/unit_tests/core/app/task_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py new file mode 100644 index 0000000000..c246f7b783 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.entities.queue_entities import QueueErrorEvent +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from models.enums import MessageStatus + + +class TestBasedGenerateTaskPipeline: + @pytest.fixture + def pipeline(self): + app_config = SimpleNamespace( + tenant_id="tenant-1", + app_id="app-1", + sensitive_word_avoidance=None, + ) + app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config) + return BasedGenerateTaskPipeline( + application_generate_entity=app_generate_entity, + queue_manager=Mock(), + stream=True, + ) + + def test_error_to_desc_quota_exceeded(self, pipeline): + message = pipeline._error_to_desc(QuotaExceededError()) + assert "quota" in message.lower() + + def test_handle_error_wraps_invoke_authorization(self, pipeline): + event = QueueErrorEvent(error=InvokeAuthorizationError()) + err = pipeline.handle_error(event=event) + assert isinstance(err, InvokeAuthorizationError) + assert str(err) == "Incorrect API key provided" + + def test_handle_error_preserves_invoke_error(self, pipeline): + event = QueueErrorEvent(error=InvokeError("bad")) + err = pipeline.handle_error(event=event) + assert err is event.error + + def test_handle_error_updates_message_when_found(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + message = SimpleNamespace(status=MessageStatus.NORMAL, error=None) + session = Mock() + session.scalar.return_value = message + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + assert message.status == MessageStatus.ERROR + assert message.error == "oops" + + def test_handle_error_returns_err_when_message_missing(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + session = Mock() + session.scalar.return_value = None + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + + def test_error_to_stream_response_and_ping(self, pipeline): + error_response = pipeline.error_to_stream_response(ValueError("boom")) + ping_response = pipeline.ping_stream_response() + + assert error_response.task_id == "task-1" + assert ping_response.task_id == "task-1" + + def test_handle_output_moderation_when_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("filtered", True) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result == "filtered" + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None + + def test_handle_output_moderation_when_not_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("safe", False) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result is None + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 13fbca6e26..1c1bf391d3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -26,8 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py new file mode 100644 index 0000000000..ea000f3886 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -0,0 +1,1228 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppStreamResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AudioTrunk +from graphon.file.enums import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent +from models.model import AppMode + + +class _DummyModelConf: + def __init__(self) -> None: + self.model = "mock" + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock", model="mock"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_entity(entity_cls, app_mode: AppMode): + app_config = _make_app_config(app_mode) + return entity_cls.model_construct( + task_id="task", + app_config=app_config, + model_conf=_DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +class TestEasyUiBasedGenerateTaskPipeline: + def test_to_blocking_response_chat(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_to_blocking_response_completion(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_listen_audio_msg_returns_none_when_no_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_process_stream_response_handles_chunks_and_end(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[TextPromptMessageContent(data="hi"), TextPromptMessageContent(data="yo")] + ), + ), + ) + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + events = [ + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline.handle_output_moderation_when_task_finished = lambda completion: None + pipeline._message_end_to_stream_response = lambda: "end" + pipeline._save_message = lambda **kwargs: None + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "chunk" in responses + assert "replace" in responses + assert any(isinstance(item, PingStreamResponse) for item in responses) + assert responses[-1] == "end" + + def test_handle_output_moderation_chunk_directs_output(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline.output_moderation_handler = _Moderation() + pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is True + assert any(isinstance(event, QueueLLMChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_stop_updates_usage(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + class _ModelType: + def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): + return LLMUsage.from_metadata( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + ) + + class _ModelConf: + def __init__(self) -> None: + self.model = "mock" + self.credentials = {} + self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + + app_config = _make_app_config(AppMode.CHAT) + application_generate_entity = ChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + model_conf=_ModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content="answer") + + calls: list[int] = [] + + class _FakeModelInstance: + def __init__(self, provider_model_bundle, model): + pass + + def get_llm_num_tokens(self, messages): + calls.append(1) + return 10 if len(calls) == 1 else 5 + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.ModelInstance", + _FakeModelInstance, + ) + + pipeline._handle_stop(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) + + assert pipeline._task_state.llm_result.usage.prompt_tokens == 10 + assert pipeline._task_state.llm_result.usage.completion_tokens == 5 + + def test_record_files_builds_file_payloads(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + message_files = [ + SimpleNamespace( + id="mf-1", + message_id="msg", + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/a.png", + upload_file_id=None, + type="image", + ), + SimpleNamespace( + id="mf-2", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-1", + type="image", + ), + SimpleNamespace( + id="mf-3", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/file.bin", + upload_file_id=None, + type="file", + ), + ] + upload_files = [ + SimpleNamespace( + id="upload-1", + name="local.png", + mime_type="image/png", + size=123, + extension="png", + ) + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else upload_files) + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "signed-url", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "signed-tool", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files + assert len(files) == 3 + + def test_process_stream_response_handles_annotation_and_error(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="agent"), + ), + ) + + events = [ + SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), + SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), + SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), + SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), + SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") + pipeline._agent_thought_to_stream_response = lambda event: "thought" + pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" + pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" + pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline.error_to_stream_response = lambda err: err + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "thought" in responses + assert "file" in responses + assert "agent" in responses + assert isinstance(responses[-1], ValueError) + assert pipeline._task_state.llm_result.message.content == "annotatedagent" + + def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_thought = SimpleNamespace( + id="thought", + position=1, + thought="t", + observation="o", + tool="tool", + tool_labels={}, + tool_input="input", + files=[], + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return agent_thought + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) + + assert response is not None + assert response.id == "thought" + + def test_process_routes_to_stream_and_starts_conversation_name_generation(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_stream_response = lambda generator: "streamed" + + result = pipeline.process() + + assert result == "streamed" + pipeline._message_cycle_manager.generate_conversation_name.assert_called_once_with( + conversation_id="conv", query="hello" + ) + + def test_process_routes_to_blocking_for_completion_mode(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock() + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_blocking_response = lambda generator: "blocking" + + result = pipeline.process() + + assert result == "blocking" + pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() + + def test_to_blocking_response_raises_error_stream_exception(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("stream error")) + + with pytest.raises(ValueError, match="stream error"): + pipeline._to_blocking_response(_gen()) + + def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(RuntimeError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_gen()) + + def test_to_stream_response_wraps_completion_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, CompletionAppStreamResponse) + assert response.message_id == "msg" + + def test_to_stream_response_wraps_chat_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, ChatbotAppStreamResponse) + assert response.conversation_id == "conv" + + def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + assert response.audio == "abc" + + def test_listen_audio_msg_returns_none_for_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + + assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses == ["payload"] + + def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def check_and_get_audio(self): + return AudioTrunk("finish", "") + + inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") + audio_calls = iter([inline_audio, None]) + pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses[0] == inline_audio + assert responses[1] == "payload" + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def __init__(self): + self._events = iter([None, AudioTrunk("responding", "later"), AudioTrunk("finish", "")]) + + def check_and_get_audio(self): + return next(self._events) + + clock = {"value": 0.0} + + def _fake_time(): + clock["value"] += 0.1 + return clock["value"] + + pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.time", _fake_time) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.sleep", lambda _: None) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._task_state.llm_result.message.content = "raw answer" + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_stop = Mock() + pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" + pipeline._save_message = lambda **kwargs: None + pipeline._message_end_to_stream_response = lambda: "end" + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["replace:moderated answer", "end"] + pipeline._handle_stop.assert_called_once() + + def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + handled = {"retriever": 0} + + def _handle_retriever_resources(event): + handled["retriever"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources + pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=retriever_event), + SimpleNamespace(event=SimpleNamespace()), + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + ] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + assert handled["retriever"] == 1 + + def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), + ) + pipeline._handle_output_moderation_chunk = lambda text: True + pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + + def test_process_stream_response_ignores_unsupported_chunk_content_types(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = SimpleNamespace( + prompt_messages=[], + delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), + ) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["ok"] + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = object() + pipeline.queue_manager.listen = lambda: iter([]) + + assert list(pipeline._process_stream_response(publisher=None)) == [] + + def test_save_message_persists_fields_and_emits_trace(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline.start_at = 10.0 + pipeline._model_config = SimpleNamespace(mode="chat") + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( + {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} + ) + + message_obj = SimpleNamespace(id="msg") + conversation_obj = SimpleNamespace(id="conv") + session = Mock() + session.scalar.side_effect = [message_obj, conversation_obj] + trace_manager = SimpleNamespace(add_trace_task=Mock()) + sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptMessageUtil.prompt_messages_to_prompt_for_saving", + lambda mode, prompt_messages: "serialized-prompt", + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptTemplateParser.remove_template_variables", + lambda text: text.replace("{{name}}", "").strip(), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.naive_utc_now", + lambda: datetime(2024, 1, 1, tzinfo=UTC), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.perf_counter", lambda: 15.0 + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.message_was_created.send", + lambda *args, **kwargs: sent_payloads.append((args, kwargs)), + ) + + pipeline._save_message(session=session, trace_manager=trace_manager) + + assert message_obj.message == "serialized-prompt" + assert message_obj.answer == "hello" + assert message_obj.provider_response_latency == 5.0 + assert trace_manager.add_trace_task.called + assert len(sent_payloads) == 1 + + def test_save_message_raises_when_message_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.return_value = None + + with pytest.raises(ValueError, match="message msg not found"): + pipeline._save_message(session=session) + + def test_save_message_raises_when_conversation_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + + with pytest.raises(ValueError, match="Conversation conv not found"): + pipeline._save_message(session=session) + + def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 2}) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.id == "msg" + assert response.metadata["usage"]["prompt_tokens"] == 1 + + def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.files is None + + def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + message_files = [ + SimpleNamespace( + id="mf-local-fallback", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-missing", + type="file", + ), + SimpleNamespace( + id="mf-tool-http", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="http://cdn.example.com/file.txt?x=1", + upload_file_id=None, + type="file", + ), + SimpleNamespace( + id="mf-tool-noext", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/path/toolid", + upload_file_id=None, + type="file", + ), + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else []) + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "local-fallback-signed", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "tool-signed", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files is not None + assert files[0]["url"] == "local-fallback-signed" + assert files[1]["filename"] == "file.txt" + assert files[2]["extension"] == ".bin" + + def test_agent_message_to_stream_response_builds_payload(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + response = pipeline._agent_message_to_stream_response(answer="hello", message_id="msg") + + assert response.id == "msg" + assert response.answer == "hello" + + def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) + + assert response is None + + def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + appended_tokens: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + appended_tokens.append(text) + + pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("next-token") + + assert result is False + assert appended_tokens == ["next-token"] diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..abfbcdb941 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from graphon.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_exc.py b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py new file mode 100644 index 0000000000..9ea7e96e73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py @@ -0,0 +1,11 @@ +from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError + + +class TestTaskPipelineExceptions: + def test_record_not_found_error_message(self): + err = RecordNotFoundError("Message", "msg-1") + assert str(err) == "Message with id msg-1 not found" + + def test_workflow_run_not_found_error_message(self): + err = WorkflowRunNotFoundError("run-1") + assert str(err) == "WorkflowRun with id run-1 not found" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index c0c636715d..07ee75ed35 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,12 +1,16 @@ """Unit tests for the message cycle manager optimization.""" +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from flask import current_app +from flask import Flask, current_app -from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from models.model import AppMode class TestMessageCycleManagerOptimization: @@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE mock_session.scalar.assert_called_once() + def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager): + """Return MESSAGE_FILE directly from in-memory cache without opening a DB session.""" + message_cycle_manager._message_has_file.add("cached-message") + + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + result = message_cycle_manager.get_message_event_type("cached-message") + + assert result == StreamEvent.MESSAGE_FILE + mock_session_factory.create_session.assert_not_called() + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization: assert chunk2_response.event == StreamEvent.MESSAGE assert chunk1_response.answer == "Chunk 1" assert chunk2_response.answer == "Chunk 2" + + def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager): + """Return None when completion entities are used for conversation naming. + + Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity. + Returns: None, indicating no name generation for completion apps. + Side effects: None expected. + """ + + class DummyCompletion: + pass + + with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion): + message_cycle_manager._application_generate_entity = DummyCompletion() + result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi") + + assert result is None + + def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager): + """Spawn background generation thread for the first chat message.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True} + flask_app = object() + + class DummyTimer: + def __init__(self, interval, function, args=None, kwargs=None): + self.interval = interval + self.function = function + self.args = args or [] + self.kwargs = kwargs + self.daemon = False + self.started = False + + def start(self): + self.started = True + + with ( + patch( + "core.app.task_pipeline.message_cycle_manager.current_app", + new=SimpleNamespace(_get_current_object=lambda: flask_app), + ), + patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer), + ): + thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello") + + assert isinstance(thread, DummyTimer) + assert thread.interval == 1 + assert thread.function == message_cycle_manager._generate_conversation_name_worker + assert thread.started is True + assert thread.daemon is True + assert thread.kwargs["flask_app"] is flask_app + assert thread.kwargs["conversation_id"] == "conv-1" + assert thread.kwargs["query"] == "hello" + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + + def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager): + """Skip thread creation when auto naming is disabled but still mark conversation as not new.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False} + + with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer: + result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello") + + assert result is None + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + mock_timer.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager): + """Return early when the conversation cannot be found.""" + flask_app = Flask(__name__) + db_session = Mock() + db_session.scalar.return_value = None + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager): + """Return early when non-completion conversation has no app relation.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id") + db_session = Mock() + db_session.scalar.return_value = conversation + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager): + """Use cached conversation name when present and avoid LLM call.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = b"cached-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "cached-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_llm_generator.generate_conversation_name.assert_not_called() + mock_redis.setex.assert_not_called() + + def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager): + """Generate conversation name and write it to redis cache on cache miss.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.return_value = "generated-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "generated-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + """Fallback to truncated query when LLM generation fails.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + long_query = "q" * 60 + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, + patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") + mock_dify_config.DEBUG = True + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + + assert conversation.name == (long_query[:47] + "...") + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_logger.exception.assert_called_once() + + def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): + """Populate task metadata from annotation reply events. + + Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService. + Returns: The fetched annotation object. + Side effects: Updates metadata.annotation_reply with id and account name. + """ + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + annotation = SimpleNamespace( + id="ann-1", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = annotation + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="ann-1") + ) + + assert result == annotation + assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1" + assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice" + + def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager): + """Return None and keep metadata unchanged when annotation is not found.""" + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = None + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="missing") + ) + + assert result is None + assert message_cycle_manager._task_state.metadata.annotation_reply is None + + def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager): + """Merge retriever resources, deduplicate, and preserve ordering positions. + + Args: message_cycle_manager with show_retrieve_source enabled and existing metadata. + Returns: None. + Side effects: Updates metadata.retriever_resources with unique items and positions. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing])) + + duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2") + + event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource]) + message_cycle_manager.handle_retriever_resources(event) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2 + + def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager): + """Build a stream response with a signed tool file URL. + + Args: message_cycle_manager with mocked Session/db and sign_tool_file. + Returns: MessageStreamResponse with signed url and belongs_to normalized to user. + Side effects: Calls sign_tool_file for tool file ids. + """ + message_cycle_manager._application_generate_entity.task_id = "task-1" + + message_file = SimpleNamespace( + id="file-1", + type="image", + belongs_to=None, + url="tool://file.verylongextension", + message_id="msg-1", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-url" + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1")) + + assert response.url == "signed-url" + assert response.belongs_to == "user" + mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin") + + def test_handle_retriever_resources_requires_features(self, message_cycle_manager): + """Raise when retriever resources are handled without feature config. + + Args: message_cycle_manager with additional_features unset and empty metadata. + Raises: ValueError when show_retrieve_source configuration is missing. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with pytest.raises(ValueError): + message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[])) + + def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager): + """Ignore null resource entries while preserving valid resources.""" + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[])) + resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + + message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource])) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + + def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager): + """Use original URL when message file URL is already HTTP.""" + message_cycle_manager._application_generate_entity.task_id = "task-http" + message_file = SimpleNamespace( + id="file-http", + type="image", + belongs_to="assistant", + url="http://example.com/pic.png", + message_id="msg-http", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-http") + ) + + assert response is not None + assert response.url == "http://example.com/pic.png" + assert "msg-http" in message_cycle_manager._message_has_file + + def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager): + """Default tool file extension to .bin when URL has no extension part.""" + message_cycle_manager._application_generate_entity.task_id = "task-bin" + message_file = SimpleNamespace( + id="file-bin", + type="file", + belongs_to="assistant", + url="tool-file-id", + message_id="msg-bin", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-bin-url" + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-bin") + ) + + assert response is not None + assert response.url == "signed-bin-url" + mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin") + + def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager): + """Return None when message file lookup does not find a record.""" + session = Mock() + session.scalar.return_value = None + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing")) + + assert response is None + + def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager): + """Include the provided replacement reason in the stream payload.""" + response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation") + + assert response.answer == "replaced" + assert response.reason == "moderation" diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py new file mode 100644 index 0000000000..21c761c579 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -0,0 +1,57 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from models.provider_ids import ModelProviderID + + +def test_validate_and_set_defaults_reuses_single_model_assembly(): + provider_name = str(ModelProviderID("openai")) + provider_entity = SimpleNamespace(provider=provider_name) + model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"}) + provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model]) + assembly = SimpleNamespace( + model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]), + provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations), + ) + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"stop": []}, + } + } + + with patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config) + + assert result["model"]["provider"] == provider_name + assert result["model"]["mode"] == "chat" + assert keys == ["model"] + mock_assembly.assert_called_once_with(tenant_id="tenant-1") + + +def test_convert_keeps_model_config_shape(): + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0.3, "stop": ["END"]}, + } + } + + result = ModelConfigManager.convert(config) + + assert result == ModelConfigEntity( + provider="openai", + model="gpt-4o-mini", + mode="chat", + parameters={"temperature": 0.3}, + stop=["END"], + ) diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 0f8a846d11..5c50cb78da 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -8,8 +8,8 @@ from core.app.workflow.layers.persistence import ( WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType -from dify_graph.node_events import NodeRunResult +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: @@ -58,3 +58,42 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon assert node_execution.finished_at == event_finished_at assert node_execution.elapsed_time == 2.0 + + +def test_update_node_execution_projects_start_outputs() -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-2" + node_execution.node_type = BuiltinNodeTypes.START + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="start", + title="Start", + predecessor_node_id=None, + iteration_id=None, + loop_id=None, + created_at=node_execution.created_at, + ) + + layer._update_node_execution( + node_execution, + NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + }, + ), + WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + node_execution.update_from_mapping.assert_called_once_with( + inputs={"question": "hello"}, + process_data={}, + outputs={"question": "hello"}, + metadata={}, + ) diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py new file mode 100644 index 0000000000..cddd03f4b0 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope +from core.app.workflow import file_runtime +from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime +from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType +from models import ToolFile, UploadFile + + +def _build_file( + *, + transfer_method: FileTransferMethod, + reference: str | None = None, + remote_url: str | None = None, + extension: str | None = None, +) -> File: + return File( + id="file-id", + type=FileType.IMAGE, + transfer_method=transfer_method, + reference=reference, + remote_url=remote_url, + filename="diagram.png", + extension=extension, + mime_type="image/png", + size=128, + ) + + +def _build_runtime() -> DifyWorkflowFileRuntime: + return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()) + + +def test_resolve_file_url_returns_remote_url() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/diagram.png", + ) + + assert runtime.resolve_file_url(file=file) == "https://example.com/diagram.png" + + +def test_resolve_file_url_requires_file_reference() -> None: + runtime = _build_runtime() + file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None) + + with pytest.raises(ValueError, match="Missing file reference"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_requires_extension_for_tool_files() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=None, + ) + + with pytest.raises(ValueError, match="Missing file extension"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sign_tool_file = MagicMock(return_value="https://signed.example.com/file") + monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file) + runtime = _build_runtime() + + tool_file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=".png", + ) + datasource_file = _build_file( + transfer_method=FileTransferMethod.DATASOURCE_FILE, + reference=build_file_reference(record_id="datasource-file-id"), + extension=".png", + ) + + assert runtime.resolve_file_url(file=tool_file) == "https://signed.example.com/file" + assert runtime.resolve_file_url(file=datasource_file) == "https://signed.example.com/file" + assert sign_tool_file.call_count == 2 + + +def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr( + "core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", + "https://internal.example.com", + ) + + runtime = _build_runtime() + url = runtime.resolve_upload_file_url( + upload_file_id="upload-file-id", + as_attachment=True, + for_external=False, + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-file-id/file-preview" + assert query["as_attachment"] == ["true"] + assert query["timestamp"] == ["1700000000"] + + +def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60) + runtime = _build_runtime() + payload = "file-preview|upload-file-id|1700000000|nonce" + sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode() + + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is True + ) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000100) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is False + ) + + +def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: b"image-bytes") + + assert runtime.load_file_bytes(file=file) == b"image-bytes" + session.get.assert_called_with(UploadFile, "upload-file-id") + + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: "not-bytes") + with pytest.raises(ValueError, match="is not a bytes object"): + runtime.load_file_bytes(file=file) + + +def test_resolve_storage_key_ignores_encoded_reference_when_unscoped(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + session.get.assert_called_once_with(UploadFile, "upload-file-id") + + +def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key") + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id") + + +def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = None + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match="Upload file upload-file-id not found"): + runtime.resolve_upload_file_url(upload_file_id="upload-file-id") + + +@pytest.mark.parametrize( + ("transfer_method", "record_id", "expected_storage_key"), + [ + (FileTransferMethod.LOCAL_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.DATASOURCE_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.TOOL_FILE, "tool-file-id", "tool-storage-key"), + ], +) +def test_resolve_storage_key_loads_database_records( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + record_id: str, + expected_storage_key: str, +) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + + def get(model_class, value): + if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}: + assert model_class is UploadFile + return SimpleNamespace(key="upload-storage-key") + assert model_class is ToolFile + return SimpleNamespace(file_key="tool-storage-key") + + session.get.side_effect = get + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == expected_storage_key + + +@pytest.mark.parametrize( + ("transfer_method", "expected_message"), + [ + (FileTransferMethod.LOCAL_FILE, "Upload file upload-file-id not found"), + (FileTransferMethod.TOOL_FILE, "Tool file tool-file-id not found"), + ], +) +def test_resolve_storage_key_raises_when_records_are_missing( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + expected_message: str, +) -> None: + runtime = _build_runtime() + record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id" + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + session.get.return_value = None + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match=expected_message): + runtime._resolve_storage_key(file=file) + + +def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + runtime = _build_runtime() + + assert runtime.multimodal_send_format == "url" + + with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch.object(file_runtime.storage, "load", return_value=b"data") as mock_load: + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + +def test_bind_dify_workflow_file_runtime_registers_runtime(monkeypatch: pytest.MonkeyPatch) -> None: + set_runtime = MagicMock() + monkeypatch.setattr(file_runtime, "set_workflow_file_runtime", set_runtime) + + bind_dify_workflow_file_runtime() + + set_runtime.assert_called_once() + assert isinstance(set_runtime.call_args.args[0], DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py new file mode 100644 index 0000000000..c4bfb23272 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -0,0 +1,161 @@ +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes + + +class DummyNode: + def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = id + self.config = config + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self.kwargs = kwargs + + +class DummyCodeNode(DummyNode): + @classmethod + def default_code_providers(cls): + return () + + +class DummyTemplateTransformNode(DummyNode): + pass + + +class DummyHttpRequestNode(DummyNode): + pass + + +class DummyKnowledgeRetrievalNode(DummyNode): + pass + + +class DummyDocumentExtractorNode(DummyNode): + pass + + +class TestDifyNodeFactory: + @staticmethod + def _stub_node_resolution(monkeypatch, node_class): + monkeypatch.setattr( + "core.workflow.node_factory.resolve_workflow_node_class", + lambda **_kwargs: node_class, + ) + + def _factory(self, monkeypatch): + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100) + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u") + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key") + + run_context = build_dify_run_context( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + return DifyNodeFactory( + graph_init_params=SimpleNamespace(run_context=run_context), + graph_runtime_state=SimpleNamespace(), + ) + + def test_create_node_unknown_type(self, monkeypatch): + factory = self._factory(monkeypatch) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": "unknown"}}) + + def test_create_node_missing_mapping(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {}) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_missing_latest_class(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr( + "core.workflow.node_factory.get_node_type_classes_mapping", + lambda: {BuiltinNodeTypes.START: {"1": None}}, + ) + monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest") + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_selects_versioned_class(self, monkeypatch): + factory = self._factory(monkeypatch) + selected_versions: list[tuple[str, str]] = [] + + class DummyNodeV2(DummyNode): + pass + + def _get_mapping(): + selected_versions.append(("snapshot", "called")) + return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}} + + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}}) + + assert isinstance(node, DummyNodeV2) + assert node.id == "node-1" + assert selected_versions == [("snapshot", "called")] + + def test_create_node_code_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyCodeNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}}) + + assert isinstance(node, DummyCodeNode) + assert node.id == "node-1" + + def test_create_node_template_transform_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) + + assert isinstance(node, DummyTemplateTransformNode) + assert "jinja2_template_renderer" in node.kwargs + + def test_create_node_http_request_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyHttpRequestNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}) + + assert isinstance(node, DummyHttpRequestNode) + assert "http_request_config" in node.kwargs + + def test_create_node_knowledge_retrieval_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}) + + assert isinstance(node, DummyKnowledgeRetrievalNode) + assert node.kwargs == {} + + def test_create_node_document_extractor_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}}) + + assert isinstance(node, DummyDocumentExtractorNode) + assert "unstructured_api_config" in node.kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py new file mode 100644 index 0000000000..82552470a9 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes + + +class TestObservabilityLayerExtras: + def test_init_tracer_enabled_sets_tracer(self, monkeypatch): + tracer = object() + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer) + + layer = ObservabilityLayer() + + assert layer._is_disabled is False + assert layer._tracer is tracer + + def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + def _raise(*_args, **_kwargs): + raise RuntimeError("tracer init failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + assert layer._tracer is None + assert "Failed to get OpenTelemetry tracer" in caplog.text + + def test_init_tracer_disables_when_otel_disabled(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + + def test_get_parser_uses_registry_when_node_type_matches(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL)) + + assert parser is layer._parsers[BuiltinNodeTypes.TOOL] + + def test_get_parser_defaults_when_node_type_missing(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=None)) + + assert parser is layer._default_parser + + def test_on_graph_start_clears_contexts(self): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_start() + + assert layer._node_contexts == {} + + def test_on_event_is_noop(self): + layer = ObservabilityLayer() + + layer.on_event(object()) + + def test_on_graph_end_clears_unfinished_contexts(self, caplog): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_end(error=None) + + assert layer._node_contexts == {} + assert "node spans were not properly ended" in caplog.text + + def test_on_node_run_start_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + layer._tracer = None + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object()) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self): + layer = ObservabilityLayer() + layer._is_disabled = False + calls: list[str] = [] + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called")) + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert calls == [] + + def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + def _raise(*_args, **_kwargs): + raise RuntimeError("start failed") + + layer._tracer = SimpleNamespace(start_span=_raise) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert "Failed to create OpenTelemetry span for node" in caplog.text + + def test_on_node_run_end_without_context_noop(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None) + + assert "exec" in layer._node_contexts + + def test_on_node_run_end_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_calls_span_end(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + ended: list[str] = [] + + class _Parser: + def parse(self, **_kwargs): + return None + + span = SimpleNamespace(end=lambda: ended.append("ended")) + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert ended == ["ended"] + assert "exec" not in layer._node_contexts + + def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + class _Parser: + def parse(self, **_kwargs): + return None + + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token") + + def _raise(*_args, **_kwargs): + raise RuntimeError("detach failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert "Failed to detach OpenTelemetry token" in caplog.text + assert "exec" not in layer._node_contexts + + def test_on_node_run_start_and_end_creates_span(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + + span = SimpleNamespace(end=lambda: None) + tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span) + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token") + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None) + + layer._tracer = tracer + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + + layer.on_node_run_start(node) + assert "exec" in layer._node_contexts + + layer.on_node_run_end(node, error=None) + assert "exec" not in layer._node_contexts diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py new file mode 100644 index 0000000000..9863f34aba --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.entities.pause_reason import SchedulingPause +from graphon.entities.workflow_node_execution import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from graphon.graph_events.graph import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.graph_events.node import ( + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool + + +class _RepoRecorder: + def __init__(self) -> None: + self.saved: list[object] = [] + self.saved_exec_data: list[object] = [] + + def save(self, entity): + self.saved.append(entity) + + def save_execution_data(self, entity): + self.saved_exec_data.append(entity) + + +def _naive_utc_now() -> datetime: + return datetime.now(UTC).replace(tzinfo=None) + + +def _make_layer( + system_variables: list | None = None, + *, + extras: dict | None = None, + trace_manager: object | None = None, +): + system_variables = system_variables or build_system_variables( + workflow_execution_id="run-id", + conversation_id="conv-id", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) + read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) + + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=SimpleNamespace(app_id="app", tenant_id="tenant"), + inputs={"foo": "bar"}, + files=[], + user_id="user", + stream=False, + invoke_from=None, + trace_manager=None, + workflow_execution_id="run-id", + extras=extras or {}, + call_depth=0, + ) + + workflow_info = PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={"nodes": [], "edges": []}, + ) + + workflow_execution_repo = _RepoRecorder() + workflow_node_execution_repo = _RepoRecorder() + + layer = WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=workflow_info, + workflow_execution_repository=workflow_execution_repo, + workflow_node_execution_repository=workflow_node_execution_repo, + trace_manager=trace_manager, + ) + layer.initialize(read_only_state, command_channel=None) + + return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state + + +class TestWorkflowPersistenceLayer: + def test_on_graph_start_resets_state(self): + layer, _, _, _ = _make_layer() + layer._workflow_execution = object() + layer._node_execution_cache["cached"] = object() + layer._node_snapshots["cached"] = object() + layer._node_sequence = 9 + + layer.on_graph_start() + + assert layer._workflow_execution is None + assert layer._node_execution_cache == {} + assert layer._node_snapshots == {} + assert layer._node_sequence == 0 + + def test_get_execution_id_requires_system_variable(self): + layer, _, _, _ = _make_layer(build_system_variables()) + + with pytest.raises(ValueError, match="workflow_execution_id must be provided"): + layer._get_execution_id() + + def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch): + layer, _, _, _ = _make_layer() + + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda inputs: inputs, + ) + + inputs = layer._prepare_workflow_inputs() + + assert "sys.conversation_id" not in inputs + assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id" + + def test_fail_running_node_executions_marks_failed(self): + layer, _, node_repo, _ = _make_layer() + + execution = WorkflowNodeExecution( + id="exec-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[execution.id] = execution + + layer._fail_running_node_executions(error_message="boom") + + assert execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_repo.saved + + def test_handle_graph_run_started_saves_execution(self): + layer, exec_repo, _, _ = _make_layer() + + layer._handle_graph_run_started() + + assert exec_repo.saved + + def test_handle_graph_run_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 3 + runtime_state.node_run_steps = 2 + runtime_state.outputs = {"out": "v"} + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.SUCCEEDED + assert saved.total_tokens == 3 + assert saved.total_steps == 2 + + def test_handle_graph_run_partial_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 5 + runtime_state.node_run_steps = 4 + runtime_state._graph_execution = SimpleNamespace(exceptions_count=2) + + layer._handle_graph_run_partial_succeeded( + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2) + ) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert saved.exceptions_count == 2 + assert saved.total_tokens == 5 + + def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self): + trace_tasks: list[object] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager) + layer._handle_graph_run_started() + + running = WorkflowNodeExecution( + id="node-exec", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[running.id] = running + + layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1)) + + assert node_repo.saved + assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED + assert trace_tasks + + def test_handle_graph_run_aborted_sets_status(self): + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + + layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.STOPPED + assert saved.error_message + + def test_handle_graph_run_paused_updates_outputs(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 7 + runtime_state.node_run_steps = 5 + + layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PAUSED + assert saved.outputs == {"pause": True} + assert saved.finished_at is None + + def test_handle_node_started_and_retry(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + predecessor_node_id="prev", + in_iteration_id="iter", + in_loop_id="loop", + ) + layer._handle_node_started(start_event) + + assert node_repo.saved + assert "exec" in layer._node_execution_cache + assert layer._node_snapshots["exec"].node_id == "node" + + retry_event = NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + layer._handle_node_retry(retry_event) + assert node_repo.saved_exec_data + + def test_handle_node_result_events_update_execution(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={}) + success_event = NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + node_run_result=result, + ) + layer._handle_node_succeeded(success_event) + + failed_event = NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="boom", + node_run_result=result, + ) + layer._handle_node_failed(failed_event) + + exception_event = NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="err", + node_run_result=result, + ) + layer._handle_node_exception(exception_event) + + assert node_repo.saved_exec_data + + def test_handle_node_pause_requested_skips_outputs(self): + layer, _, _, _ = _make_layer() + layer._handle_graph_run_started() + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + domain_execution = layer._node_execution_cache["exec"] + domain_execution.inputs = {"old": True} + + result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={}) + pause_event = NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + reason=SchedulingPause(message="pause"), + node_run_result=result, + ) + layer._handle_node_pause_requested(pause_event) + + assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED + assert domain_execution.inputs == {"old": True} + + def test_get_node_execution_raises_for_missing(self): + layer, _, _, _ = _make_layer() + with pytest.raises(ValueError, match="Node execution not found"): + layer._get_node_execution("missing") + + def test_get_workflow_execution_raises_when_uninitialized(self): + layer, _, _, _ = _make_layer() + + with pytest.raises(ValueError, match="workflow execution not initialized"): + layer._get_workflow_execution() + + def test_next_node_sequence_increments(self): + layer, _, _, _ = _make_layer() + assert layer._next_node_sequence() == 1 + assert layer._next_node_sequence() == 2 + + def test_on_graph_end_is_noop(self): + layer, _, _, _ = _make_layer() + + assert layer.on_graph_end(error=None) is None + + def test_on_event_dispatches_to_all_known_handlers(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_graph_run_started = _record("started") + layer._handle_graph_run_succeeded = _record("succeeded") + layer._handle_graph_run_partial_succeeded = _record("partial") + layer._handle_graph_run_failed = _record("failed") + layer._handle_graph_run_aborted = _record("aborted") + layer._handle_graph_run_paused = _record("paused") + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + layer._handle_node_succeeded = _record("node_succeeded") + layer._handle_node_failed = _record("node_failed") + layer._handle_node_exception = _record("node_exception") + layer._handle_node_pause_requested = _record("node_paused") + + node_result = NodeRunResult() + now = _naive_utc_now() + events = [ + GraphRunStartedEvent(), + GraphRunSucceededEvent(outputs={"ok": True}), + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1), + GraphRunFailedEvent(error="boom", exceptions_count=1), + GraphRunAbortedEvent(reason="stop", outputs={"x": 1}), + GraphRunPausedEvent(outputs={"pause": True}), + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + ), + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + error="retry", + retry_index=1, + ), + NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + node_run_result=node_result, + ), + NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="failed", + node_run_result=node_result, + ), + NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="error", + node_run_result=node_result, + ), + NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + reason=SchedulingPause(message="pause"), + node_run_result=node_result, + ), + ] + expected_order = [ + "started", + "succeeded", + "partial", + "failed", + "aborted", + "paused", + "node_started", + "node_retry", + "node_succeeded", + "node_failed", + "node_exception", + "node_paused", + ] + + for event in events: + layer.on_event(event) + + assert called == expected_order + + def test_on_event_dispatches_retry_before_started_for_retry_event(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + + layer.on_event( + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + ) + + assert called == ["node_retry"] + + def test_enqueue_trace_task_skips_when_disabled(self): + trace_tasks: list[object] = [] + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + assert exec_repo.saved + assert not trace_tasks diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 3759b6aa37..7b433ab57b 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -28,10 +28,7 @@ def mock_model_instance(mocker): def mock_model_manager(mocker, mock_model_instance): manager = mocker.MagicMock() manager.get_default_model_instance.return_value = mock_model_instance - mocker.patch( - "core.base.tts.app_generator_tts_publisher.ModelManager", - return_value=manager, - ) + mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager) return manager @@ -64,16 +61,14 @@ class TestInvoiceTTS: [None, "", " "], ) def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): - result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + result = _invoice_tts(text, mock_model_instance, "voice1") assert result is None mock_model_instance.invoke_tts.assert_not_called() def test_invoice_tts_valid_text(self, mock_model_instance): - result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + result = _invoice_tts(" hello ", mock_model_instance, "voice1") mock_model_instance.invoke_tts.assert_called_once_with( content_text="hello", - user="responding_tts", - tenant_id="tenant", voice="voice1", ) assert result == [b"audio1", b"audio2"] @@ -307,8 +302,8 @@ class TestAppGeneratorTTSPublisher: publisher.executor = MagicMock() from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import ( + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, @@ -342,8 +337,8 @@ class TestAppGeneratorTTSPublisher: publisher.executor = MagicMock() from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage chunk = LLMResultChunk( model="model", diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d5eeae912c..af992e4e9f 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -7,10 +7,11 @@ from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from core.workflow.file_reference import parse_file_reference +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.file import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -428,11 +429,8 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): return fake_tool_file mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) - mocker.patch( - "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE - ) + mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", @@ -533,7 +531,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - tenant_id="t1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", @@ -664,6 +661,8 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + assert parse_file_reference(f.reference).storage_key is None + assert f.storage_key == "k" def test_get_upload_file_by_id_raises_when_missing(mocker): diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 43f582feb7..0b91d59953 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -4,8 +4,8 @@ import pytest from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType +from graphon.file import File +from graphon.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index 2e4f6d34fb..ef8f360dbf 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -4,8 +4,8 @@ from core.entities.execution_extra_content import ( HumanInputFormDefinition, HumanInputFormSubmissionData, ) -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 7a3d5e84ed..a0b2820157 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -16,9 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 95d58757f1..fe2c226843 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -24,9 +24,9 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FieldModelSchema, @@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( "core.entities.provider_configuration.encrypter.encrypt_token", @@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None: with patch("core.entities.provider_configuration.db") as mock_db: mock_db.engine = Mock() mock_session_cls.return_value.__enter__.return_value = mock_session - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -434,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder: model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema + assert mock_factory_builder.call_count == 2 mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", @@ -449,6 +455,33 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: ) +def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: + configuration = _build_provider_configuration() + bound_runtime = Mock() + configuration.bind_model_runtime(bound_runtime) + + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with ( + patch( + "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory + ) as mock_factory_cls, + patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + ): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + assert mock_factory_cls.call_count == 2 + mock_factory_cls.assert_called_with(model_runtime=bound_runtime) + mock_factory_builder.assert_not_called() + + def test_get_provider_model_returns_none_when_model_not_found() -> None: configuration = _build_provider_configuration() fake_model = SimpleNamespace(model="other-model") @@ -475,7 +508,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -689,7 +722,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1034,7 +1067,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( @@ -1050,7 +1083,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"region": "us"} with _patched_session(session): - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1540,7 +1575,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1662,7 +1697,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index c5bfd05a1e..a159d3ad4d 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -8,7 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index deebf41320..bb6e40e224 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from dify_graph.file import File, FileTransferMethod, FileType +from graphon.file import File, FileTransferMethod, FileType def test_file(): @@ -15,18 +15,17 @@ def test_file(): storage_key="test-storage-key", url="https://example.com/image.png", ) - assert file.tenant_id == "test-tenant-id" assert file.type == FileType.IMAGE assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" assert file.filename == "image.png" assert file.extension == ".png" assert file.mime_type == "image/png" assert file.size == 67 -def test_file_model_validate_with_legacy_fields(): - """Test `File` model can handle data containing compatibility fields.""" +def test_file_model_validate_accepts_legacy_tenant_id(): data = { "id": "test-file", "tenant_id": "test-tenant-id", @@ -45,10 +44,8 @@ def test_file_model_validate_with_legacy_fields(): "datasource_file_id": "datasource-file-789", } - # Should be able to create `File` object without raising an exception file = File.model_validate(data) - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" + assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 46c9dc6f9c..6ed9ddb476 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -16,20 +16,20 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultWithStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f..b3a5885814 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -6,14 +6,14 @@ import pytest from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = instance mock_response = MagicMock() diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f982765b1a..bfb1fde502 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -18,7 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 5ecfe01808..f459250b8e 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -6,7 +6,7 @@ from uuid import uuid4 import pytest from core.memory.token_buffer_memory import TokenBufferMemory -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000..249ecb5006 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,420 @@ +from unittest.mock import Mock + +import pytest + +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _build_model(model: str, model_type: ModelType) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _build_provider( + *, + provider: str, + provider_name: str, + supported_model_types: list[ModelType], + models: list[AIModelEntity] | None = None, + provider_credential_schema: ProviderCredentialSchema | None = None, + model_credential_schema: ModelCredentialSchema | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + provider_name=provider_name, + label=I18nObject(en_US=provider_name or provider), + supported_model_types=supported_model_types, + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + provider_credential_schema=provider_credential_schema, + model_credential_schema=model_credential_schema, + ) + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + self.validate_provider_credentials = Mock() + self.validate_model_credentials = Mock() + self.get_model_schema = Mock() + self.get_provider_icon = Mock() + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + + +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + +def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ), + _build_provider( + provider="langgenius/cohere/cohere", + provider_name="cohere", + supported_model_types=[ModelType.RERANK], + models=[_build_model("rerank-v3", ModelType.RERANK)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai", model_type=ModelType.LLM) + + assert len(results) == 1 + assert results[0].provider == "langgenius/openai/openai" + assert [model.model for model in results[0].models] == ["gpt-4o-mini"] + + +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + +def test_model_provider_factory_validates_provider_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + provider_credential_schema=ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ] + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.provider_credentials_validate( + provider="openai", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_provider_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + +def test_model_provider_factory_validates_model_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_model_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + runtime.get_model_schema.return_value = "schema" + runtime.get_provider_icon.return_value = (b"icon", "image/png") + factory = ModelProviderFactory(model_runtime=runtime) + + assert ( + factory.get_model_schema( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials=None, + ) + == "schema" + ) + assert factory.get_provider_icon("openai", "icon_small", "en_US") == (b"icon", "image/png") + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + runtime.get_provider_icon.assert_called_once_with( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + +@pytest.mark.parametrize( + ("model_type", "expected_type"), + [ + (ModelType.LLM, LargeLanguageModel), + (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), + (ModelType.RERANK, RerankModel), + (ModelType.SPEECH2TEXT, Speech2TextModel), + (ModelType.MODERATION, ModerationModel), + (ModelType.TTS, TTSModel), + ], +) +def test_model_provider_factory_builds_model_type_instances( + model_type: ModelType, + expected_type: type[object], +) -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] + ) + ) + + instance = factory.get_model_type_instance("openai", model_type) + + assert isinstance(instance, expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e7..3a97ad5c5d 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa7..c2324fdec4 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -34,8 +34,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index 763fc90710..fa885e9320 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -24,8 +24,8 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b68..4ce9e22fd7 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562c..fdf66d4d40 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -25,7 +25,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -680,7 +680,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f..e89359c25b 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -21,7 +21,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -565,7 +565,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index cccedaa08c..7ff6f7dcfd 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d5109..6625cb719f 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -18,7 +18,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -657,7 +657,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index a0b6d52720..6113e5c6c8 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -25,8 +25,8 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639f..265652381c 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -14,8 +14,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 49d6b698ef..4b925390d9 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/test_lookup_helpers.py b/api/tests/unit_tests/core/ops/test_lookup_helpers.py new file mode 100644 index 0000000000..86aa68643d --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_lookup_helpers.py @@ -0,0 +1,554 @@ +"""Unit tests for lookup helper functions in core.ops.ops_trace_manager. + +Covers: +- _lookup_app_and_workspace_names +- _lookup_credential_name +- _lookup_llm_credential_info +- TraceTask._get_user_id_from_metadata +""" + +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None): + """Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db' + and 'core.ops.ops_trace_manager.Session'. + + Provide either scalar_side_effect (list, for multiple calls) or + scalar_return_value (single value). + """ + mock_db = MagicMock() + mock_db.engine = MagicMock() + + session = MagicMock() + if scalar_side_effect is not None: + session.scalar.side_effect = scalar_side_effect + else: + session.scalar.return_value = scalar_return_value + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=session) + cm.__exit__ = MagicMock(return_value=False) + + return mock_db, cm, session + + +# --------------------------------------------------------------------------- +# _lookup_app_and_workspace_names +# --------------------------------------------------------------------------- + + +class TestLookupAppAndWorkspaceNames: + """Tests for _lookup_app_and_workspace_names(app_id, tenant_id).""" + + def test_both_found(self): + """Returns (app_name, workspace_name) when both records exist.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "MyWorkspace" + + def test_app_only_found(self): + """Returns (app_name, '') when tenant record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "" + + def test_tenant_only_found(self): + """Returns ('', workspace_name) when app record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "MyWorkspace" + + def test_neither_found(self): + """Returns ('', '') when both DB lookups return None.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "" + + def test_none_inputs_skips_db(self): + """Returns ('', '') immediately when both IDs are None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, None) + + mock_session_cls.assert_not_called() + assert app_name == "" + assert workspace_name == "" + + def test_app_id_none_only_queries_tenant(self): + """When app_id is None, only the tenant query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456") + + assert app_name == "" + assert workspace_name == "OnlyWorkspace" + assert session.scalar.call_count == 1 + + def test_tenant_id_none_only_queries_app(self): + """When tenant_id is None, only the app query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None) + + assert app_name == "OnlyApp" + assert workspace_name == "" + assert session.scalar.call_count == 1 + + +# --------------------------------------------------------------------------- +# _lookup_credential_name +# --------------------------------------------------------------------------- + + +class TestLookupCredentialName: + """Tests for _lookup_credential_name(credential_id, provider_type).""" + + @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"]) + def test_known_provider_types_return_name(self, provider_type): + """Each valid provider_type results in a DB query and returns the credential name.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-123", provider_type) + + assert result == "CredentialA" + session.scalar.assert_called_once() + + def test_credential_not_found_returns_empty_string(self): + """Returns '' when DB yields None for the given credential_id.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-999", "api") + + assert result == "" + + def test_invalid_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately for an unrecognised provider_type — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", "unknown_type") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_credential_id_returns_empty_string_without_db(self): + """Returns '' immediately when credential_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name(None, "api") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately when provider_type is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", None) + + mock_session_cls.assert_not_called() + assert result == "" + + def test_builtin_and_plugin_map_to_same_model(self): + """Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import BuiltinToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider + assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider + + def test_api_maps_to_api_tool_provider(self): + """'api' maps to ApiToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import ApiToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider + + def test_workflow_maps_to_workflow_tool_provider(self): + """'workflow' maps to WorkflowToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import WorkflowToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider + + def test_mcp_maps_to_mcp_tool_provider(self): + """'mcp' maps to MCPToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import MCPToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider + + +# --------------------------------------------------------------------------- +# _lookup_llm_credential_info +# --------------------------------------------------------------------------- + + +class TestLookupLlmCredentialInfo: + """Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type).""" + + def _provider_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def _model_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def test_model_level_credential_found(self): + """Returns model-level credential_id and name when ProviderModel has a credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id="model-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ModelCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "model-cred-id" + assert cred_name == "ModelCredName" + + def test_provider_level_fallback_when_no_model_credential(self): + """Falls back to provider-level credential when ProviderModel has no credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + model_record = self._model_record(credential_id=None) + + # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ProvCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_provider_level_fallback_when_no_model_record(self): + """Falls back to provider-level credential when no ProviderModel row exists.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_no_model_arg_uses_provider_level_only(self): + """When model is None, skips ProviderModel query and uses provider credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel + mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None) + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + assert session.scalar.call_count == 2 + + def test_provider_not_found_returns_none_and_empty(self): + """Returns (None, '') when Provider record does not exist.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_none_tenant_id_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when tenant_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_none_provider_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when provider is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_db_error_on_outer_query_returns_none_and_empty(self): + """Returns (None, '') and logs a warning when the outer DB query raises.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, session = _make_db_and_session_patches() + session.scalar.side_effect = Exception("DB connection failed") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_credential_name_lookup_failure_returns_id_with_empty_name(self): + """When credential name sub-query fails, returns cred_id but '' for name.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # Provider found, no model record, then name lookup raises + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, None, Exception("deleted")] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "" + + def test_no_credential_on_provider_or_model_returns_none_id(self): + """Returns (None, '') when neither provider nor model has a credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id=None) + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + +# --------------------------------------------------------------------------- +# TraceTask._get_user_id_from_metadata +# --------------------------------------------------------------------------- + + +class TestGetUserIdFromMetadata: + """Tests for TraceTask._get_user_id_from_metadata(metadata). + + Pure dict logic — no DB access required. + """ + + @pytest.fixture + def get_user_id(self): + """Return the classmethod under test.""" + from core.ops.ops_trace_manager import TraceTask + + return TraceTask._get_user_id_from_metadata + + def test_from_end_user_id_has_highest_priority(self, get_user_id): + """from_end_user_id takes precedence over all other keys.""" + metadata = { + "from_end_user_id": "eu-abc", + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "end_user:eu-abc" + + def test_from_account_id_used_when_no_end_user(self, get_user_id): + """from_account_id is used when from_end_user_id is absent.""" + metadata = { + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_user_id_used_when_no_end_user_or_account(self, get_user_id): + """user_id is used when both higher-priority keys are absent.""" + metadata = {"user_id": "u-123"} + assert get_user_id(metadata) == "user:u-123" + + def test_returns_anonymous_when_all_keys_absent(self, get_user_id): + """Returns 'anonymous' when metadata has none of the expected keys.""" + assert get_user_id({}) == "anonymous" + + def test_empty_string_end_user_id_is_skipped(self, get_user_id): + """Empty string for from_end_user_id is falsy and falls through to next key.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "acc-xyz", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_empty_string_account_id_is_skipped(self, get_user_id): + """Empty string for from_account_id is falsy and falls through to user_id.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "u-123", + } + assert get_user_id(metadata) == "user:u-123" + + def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id): + """Empty string for user_id is falsy, so 'anonymous' is returned.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "", + } + assert get_user_id(metadata) == "anonymous" + + def test_only_from_end_user_id_present(self, get_user_id): + """Minimal case: only from_end_user_id present.""" + assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only" + + def test_irrelevant_keys_do_not_interfere(self, get_user_id): + """Extra metadata keys have no effect on the result.""" + metadata = {"invoke_from": "web", "app_id": "a1"} + assert get_user_id(metadata) == "anonymous" diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 7660967183..ad9d0846be 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 2d325ccb0e..f81806c941 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -86,6 +86,7 @@ def make_message_data(**overrides): created_at = datetime(2025, 2, 20, 12, 0, 0) base = { "id": "msg-id", + "app_id": "app-id", "conversation_id": "conv-id", "created_at": created_at, "updated_at": created_at + timedelta(seconds=3), @@ -182,6 +183,9 @@ class DummySessionContext: def __exit__(self, exc_type, exc_val, exc_tb): return False + def execute(self, *args, **kwargs): + return self + def scalar(self, *args, **kwargs): if self._index >= len(self._values): return None @@ -189,6 +193,12 @@ class DummySessionContext: self._index += 1 return value + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + @pytest.fixture(autouse=True) def patch_provider_map(monkeypatch): @@ -454,7 +464,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db): def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] - execution = SimpleNamespace(id_="run-id") + execution = SimpleNamespace(id_="run-id", total_tokens=0) task = TraceTask( trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" ) diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..a4903054e0 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,194 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at +least one consumer is active: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When neither is active, tasks are silently dropped to avoid unnecessary work. + +When BOTH are false, tasks are silently dropped (correct behavior). +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad5..8987b6682c 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 0000000000..7491e79f30 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 0000000000..c24d3ac012 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary +from graphon.model_runtime.entities.message_entities import UserPromptMessage + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + "get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 0000000000..68aa130518 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,506 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import Mock, sentinel + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl import model_runtime as model_runtime_module +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_schema() -> AIModelEntity: + return AIModelEntity( + model="gpt-4o-mini", + label=I18nObject(en_US="GPT-4o mini"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + + def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + + +def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=schema.model_dump_json()), + delete=Mock(), + setex=Mock(), + ), + ) + + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + client.get_model_schema.assert_not_called() + + +def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + delete = Mock() + setex = Mock() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value="not-json"), + delete=delete, + setex=setex, + ), + ) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) + client.get_model_schema.return_value = schema + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + delete.assert_called_once() + client.get_model_schema.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model_type=ModelType.LLM.value, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + setex.assert_called_once() + + +def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert ( + runtime.get_llm_num_tokens( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[], + tools=None, + ) + == 0 + ) + client.get_llm_num_tokens.assert_not_called() + + +def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=I18nObject(en_US="logo.svg"), + icon_small_dark=I18nObject(en_US="logo-dark.png"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + fetch_asset = Mock(return_value=b"") + monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + icon_bytes, mime_type = runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + assert icon_bytes == b"" + assert mime_type == "image/svg+xml" + fetch_asset.assert_called_once_with(tenant_id="tenant", id="logo.svg") + + +def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + with pytest.raises(ValueError, match="does not have small dark icon"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small_dark", + lang="en_US", + ) + + with pytest.raises(ValueError, match="Unsupported icon type"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_large", + lang="en_US", + ) + + +def test_get_schema_cache_key_is_stable_across_credential_order() -> None: + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + + first = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"b": "2", "a": "1"}, + ) + second = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1", "b": "2"}, + ) + + assert first == second + + +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + +def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" + + with pytest.raises(ValueError, match="Invalid provider"): + runtime._get_provider_schema("missing") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index b0b64a601b..f1c4c7e700 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -25,7 +25,7 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 4f038d4a5b..af86f917b1 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -26,6 +26,7 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -36,14 +37,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: @@ -557,7 +558,7 @@ class TestPluginRuntimeErrorHandling: with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(PluginDaemonInternalServerError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) def test_empty_data_response_error(self, plugin_client, mock_config): @@ -1808,8 +1809,8 @@ class TestPluginInstallerAdvanced: mock_response.raise_for_status = raise_for_status with patch("httpx.request", return_value=mock_response, autospec=True): - # Act & Assert - Should raise HTTPStatusError for 404 - with pytest.raises(httpx.HTTPStatusError): + # Act & Assert - Should raise PluginDaemonClientSideError for 404 + with pytest.raises(PluginDaemonClientSideError): installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") def test_list_plugins_with_pagination(self, installer, mock_config): diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index c7e94aa4cf..4d4313dd84 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -6,8 +6,8 @@ from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3d08525aba..395d392127 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -9,8 +9,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, @@ -145,7 +145,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("graphon.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 634703740c..803afa54d7 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -6,13 +6,13 @@ from core.app.entities.app_invoke_entities import ( from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 9fc300348a..5d865d934c 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,6 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index d379e3067a..9f9ea33695 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock, patch import pytest from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage -# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity -# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.entities.message_entities import UserPromptMessage +# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from graphon.model_runtime.entities.provider_entities import ProviderEntity +# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index e6d28224d7..0dc74b33df 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -18,7 +18,7 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py index 65ee62b8dd..c7a4265a95 100644 --- a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -211,3 +211,16 @@ class TestCleanProcessor: text = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)" assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_remove_urls_emails_preserves_markdown_image_links(self): + """Remove plain URLs and emails while preserving markdown image links.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)" + result = CleanProcessor.clean(text, process_rule) + + assert result == "Email and remove but keep ![diagram](https://example.com/image.png)" + + def test_filter_string_returns_input_text(self): + """Test filter_string passthrough behavior.""" + processor = CleanProcessor() + assert processor.filter_string("raw text") == "raw text" diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py new file mode 100644 index 0000000000..1f3247590c --- /dev/null +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -0,0 +1,246 @@ +from unittest.mock import MagicMock, patch + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + + +def _doc(content: str) -> Document: + return Document(page_content=content) + + +class TestDataPostProcessor: + def test_init_sets_rerank_and_reorder_runners(self): + rerank_runner = object() + reorder_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock: + with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock: + processor = DataPostProcessor( + tenant_id="tenant-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_model={"config": "value"}, + weights={"weight": "value"}, + reorder_enabled=True, + ) + + assert processor.rerank_runner is rerank_runner + assert processor.reorder_runner is reorder_runner + rerank_mock.assert_called_once_with( + RerankMode.WEIGHTED_SCORE, + "tenant-1", + {"config": "value"}, + {"weight": "value"}, + ) + reorder_mock.assert_called_once_with(True) + + def test_invoke_applies_rerank_then_reorder(self): + original_documents = [_doc("doc-a")] + reranked_documents = [_doc("doc-b")] + reordered_documents = [_doc("doc-c")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = MagicMock() + processor.rerank_runner.run.return_value = reranked_documents + processor.reorder_runner = MagicMock() + processor.reorder_runner.run.return_value = reordered_documents + + result = processor.invoke( + query="how to test", + documents=original_documents, + score_threshold=0.3, + top_n=2, + query_type=QueryType.IMAGE_QUERY, + ) + + processor.rerank_runner.run.assert_called_once_with( + "how to test", + original_documents, + 0.3, + 2, + QueryType.IMAGE_QUERY, + ) + processor.reorder_runner.run.assert_called_once_with(reranked_documents) + assert result == reordered_documents + + def test_invoke_returns_original_documents_when_no_runner_is_configured(self): + documents = [_doc("doc-a"), _doc("doc-b")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = None + processor.reorder_runner = None + + assert processor.invoke(query="query", documents=documents) == documents + + def test_get_rerank_runner_for_weighted_score(self): + weights_config = { + "vector_setting": { + "vector_weight": 0.7, + "embedding_provider_name": "provider-x", + "embedding_model_name": "embedding-y", + }, + "keyword_setting": {"keyword_weight": 0.3}, + } + expected_runner = object() + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.WEIGHTED_SCORE, + tenant_id="tenant-1", + reranking_model=None, + weights=weights_config, + ) + + assert result is expected_runner + kwargs = factory_mock.call_args.kwargs + assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE + assert kwargs["tenant_id"] == "tenant-1" + assert kwargs["weights"].vector_setting.vector_weight == 0.7 + assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x" + assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y" + assert kwargs["weights"].keyword_setting.keyword_weight == 0.3 + + def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + reranking_model = { + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + } + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock: + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner" + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model=reranking_model, + weights=None, + ) + + assert result is None + model_mock.assert_called_once_with("tenant-1", reranking_model) + factory_mock.assert_not_called() + + def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + expected_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance): + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + }, + weights=None, + ) + + assert result is expected_runner + factory_mock.assert_called_once_with( + runner_type=RerankMode.RERANKING_MODEL, + rerank_model_instance=model_instance, + ) + + def test_get_rerank_runner_returns_none_for_unsupported_mode(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None + assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None + + def test_get_reorder_runner_by_flag(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert isinstance(processor._get_reorder_runner(True), ReorderRunner) + assert processor._get_reorder_runner(False) is None + + def test_get_rerank_model_instance_returns_none_when_config_is_missing(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + assert processor._get_rerank_model_instance("tenant-1", None) is None + + def test_get_rerank_model_instance_returns_none_for_incomplete_config(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + def test_get_rerank_model_instance_success(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.return_value = model_instance + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is model_instance + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + manager_instance.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider-x", + model_type=ModelType.RERANK, + model="reranker-1", + ) + + def test_get_rerank_model_instance_handles_authorization_error(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + +class TestReorderRunner: + def test_run_reorders_even_sized_document_list(self): + documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")] + + reordered = ReorderRunner().run(documents) + + assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"] + + def test_run_handles_odd_sized_and_empty_document_lists(self): + odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")] + runner = ReorderRunner() + + odd_reordered = runner.run(odd_documents) + + assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"] + assert runner.run([]) == [] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py new file mode 100644 index 0000000000..795a325a6b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -0,0 +1,414 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.keyword.jieba.jieba as jieba_module +from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default +from core.rag.models.document import Document + + +class _DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Field: + def __init__(self, name: str): + self._name = name + + def __eq__(self, other): + return ("eq", self._name, other) + + def in_(self, values): + return ("in", self._name, tuple(values)) + + +class _FakeQuery: + def __init__(self): + self.where_calls: list[tuple] = [] + + def where(self, *conditions): + self.where_calls.append(conditions) + return self + + +class _FakeExecuteResult: + def __init__(self, segments: list[SimpleNamespace]): + self._segments = segments + + def scalars(self): + return self + + def all(self): + return self._segments + + +class _FakeSelect: + def __init__(self): + self.where_conditions: tuple | None = None + + def where(self, *conditions): + self.where_conditions = conditions + return self + + +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): + return SimpleNamespace( + data_source_type=data_source_type, + keyword_table_dict=keyword_table_dict, + keyword_table="", + ) + + +def _dataset(dataset_keyword_table=None, keyword_number=None): + return SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + keyword_number=keyword_number, + dataset_keyword_table=dataset_keyword_table, + ) + + +@pytest.fixture +def patched_runtime(monkeypatch): + session = MagicMock() + db = SimpleNamespace(session=session) + storage = MagicMock() + lock = MagicMock(return_value=_DummyLock()) + redis_client = SimpleNamespace(lock=lock) + + monkeypatch.setattr(jieba_module, "db", db) + monkeypatch.setattr(jieba_module, "storage", storage) + monkeypatch.setattr(jieba_module, "redis_client", redis_client) + + return SimpleNamespace(session=session, storage=storage, lock=lock) + + +def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime): + dataset = _dataset(_dataset_keyword_table(), keyword_number=2) + keyword = Jieba(dataset) + handler = MagicMock() + handler.extract_keywords.return_value = {"kw1", "kw2"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + result = keyword.create( + [ + Document(page_content="alpha", metadata={"doc_id": "node-1"}), + SimpleNamespace(page_content="ignored", metadata=None), + ] + ) + + assert result is keyword + keyword._update_segment_keywords.assert_called_once() + call_args = keyword._update_segment_keywords.call_args.args + assert call_args[0] == "dataset-1" + assert call_args[1] == "node-1" + assert set(call_args[2]) == {"kw1", "kw2"} + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["kw1"] == {"node-1"} + assert saved_table["kw2"] == {"node-1"} + patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600) + + +def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + texts = [ + Document(page_content="extract-this", metadata={"doc_id": "node-1"}), + Document(page_content="use-manual", metadata={"doc_id": "node-2"}), + ] + keyword.add_texts(texts, keywords_list=[[], ["manual"]]) + + assert keyword._update_segment_keywords.call_count == 2 + first_call = keyword._update_segment_keywords.call_args_list[0].args + second_call = keyword._update_segment_keywords.call_args_list[1].args + assert set(first_call[2]) == {"auto"} + assert second_call[2] == ["manual"] + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1)) + handler = MagicMock() + handler.extract_keywords.return_value = {"from-extractor"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})]) + + handler.extract_keywords.assert_called_once_with("content", 1) + assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"} + + +def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + assert keyword.text_exists("node-1") is False + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + assert keyword.text_exists("node-2") is True + assert keyword.text_exists("node-x") is False + + +def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"]) + keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}}) + + +def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_not_called() + keyword._save_dataset_keyword_table.assert_called_once_with(None) + + +def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + document_id = _Field("document_id") + + keyword = Jieba(_dataset(_dataset_keyword_table())) + query_stmt = _FakeQuery() + patched_runtime.session.query.return_value = query_stmt + patched_runtime.session.execute.return_value = _FakeExecuteResult( + [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] + ) + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) + + documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) + + assert len(query_stmt.where_calls) == 2 + assert len(documents) == 1 + assert documents[0].page_content == "segment-content" + assert documents[0].metadata["doc_id"] == "node-2" + assert documents[0].metadata["doc_hash"] == "hash-2" + + +def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime): + db_keyword = _dataset_keyword_table(data_source_type="database") + file_keyword = _dataset_keyword_table(data_source_type="object_storage") + + keyword_db = Jieba(_dataset(db_keyword)) + keyword_db.delete() + patched_runtime.storage.delete.assert_not_called() + + keyword_file = Jieba(_dataset(file_keyword)) + keyword_file.delete() + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + assert patched_runtime.session.delete.call_count == 2 + assert patched_runtime.session.commit.call_count == 2 + + +def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="database") + keyword = Jieba(_dataset(dataset_keyword_table)) + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table + assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table + patched_runtime.session.commit.assert_called_once() + + +def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="file") + keyword = Jieba(_dataset(dataset_keyword_table)) + patched_runtime.storage.exists.return_value = True + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + patched_runtime.storage.save.assert_called_once() + save_args = patched_runtime.storage.save.call_args.args + assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt" + assert isinstance(save_args[1], bytes) + + +def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime): + existing = _dataset_keyword_table( + keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}} + ) + keyword = Jieba(_dataset(existing)) + assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]} + + missing_payload = _dataset_keyword_table(keyword_table_dict=None) + keyword_with_missing_payload = Jieba(_dataset(missing_payload)) + assert keyword_with_missing_payload._get_dataset_keyword_table() == {} + + +def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime): + created_tables: list[SimpleNamespace] = [] + + def _fake_dataset_keyword_table(**kwargs): + kwargs.setdefault("keyword_table", "") + kwargs.setdefault("keyword_table_dict", None) + table = SimpleNamespace(**kwargs) + created_tables.append(table) + return table + + keyword = Jieba(_dataset(dataset_keyword_table=None)) + monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table) + monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database") + + result = keyword._get_dataset_keyword_table() + + assert result == {} + assert len(created_tables) == 1 + assert created_tables[0].dataset_id == "dataset-1" + assert created_tables[0].data_source_type == "database" + assert '"index_id":"dataset-1"' in created_tables[0].keyword_table + patched_runtime.session.add.assert_called_once_with(created_tables[0]) + patched_runtime.session.commit.assert_called_once() + + +def test_add_and_delete_ids_from_keyword_table_helpers(): + keyword = Jieba(_dataset(_dataset_keyword_table())) + keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}} + + updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"]) + assert updated["kw1"] == {"node-1", "node-3"} + assert updated["kw3"] == {"node-3"} + + deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"]) + assert "kw3" not in deleted + assert "kw1" not in deleted + assert deleted["kw2"] == {"node-2"} + + +def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + handler = MagicMock() + handler.extract_keywords.return_value = ["kw-a", "kw-b"] + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + + ranked_ids = keyword._retrieve_ids_by_query( + {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}}, + "query", + k=1, + ) + + assert ranked_ids == ["node-2"] + + +def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + + keyword = Jieba(_dataset(_dataset_keyword_table())) + segment = SimpleNamespace(keywords=[]) + patched_runtime.session.scalar.return_value = segment + + keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"]) + + assert segment.keywords == ["kw1", "kw2"] + patched_runtime.session.add.assert_called_once_with(segment) + patched_runtime.session.commit.assert_called_once() + + patched_runtime.session.reset_mock() + patched_runtime.session.scalar.return_value = None + + keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"]) + + patched_runtime.session.add.assert_not_called() + patched_runtime.session.commit.assert_not_called() + + +def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.create_segment_keywords("node-1", ["kw"]) + keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"]) + keyword._save_dataset_keyword_table.assert_called_once() + + keyword._save_dataset_keyword_table.reset_mock() + keyword.update_segment_keywords_index("node-2", ["kw2"]) + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None) + second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None) + + keyword.multi_create_segment_keywords( + [ + {"segment": first_segment, "keywords": ["manual"]}, + {"segment": second_segment, "keywords": []}, + ] + ) + + assert first_segment.keywords == ["manual"] + assert second_segment.keywords == ["auto"] + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["manual"] == {"node-1"} + assert saved_table["auto"] == {"node-2"} + + +def test_set_orjson_default_and_dumps_with_sets(): + assert set(set_orjson_default({"a", "b"})) == {"a", "b"} + + with pytest.raises(TypeError, match="is not JSON serializable"): + set_orjson_default(("not", "a", "set")) + + payload = {"items": {"a", "b"}} + json_payload = dumps_with_sets(payload) + decoded = json.loads(json_payload) + assert set(decoded["items"]) == {"a", "b"} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py new file mode 100644 index 0000000000..a4586c141b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py @@ -0,0 +1,142 @@ +import sys +import types +from types import SimpleNamespace + +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +class _DummyTFIDF: + def __init__(self): + self.stop_words = set() + + @staticmethod + def extract_tags(sentence: str, top_k: int | None = 20, **kwargs): + return ["alpha_beta", "during", "gamma"] + + +def _install_fake_jieba_modules( + monkeypatch, + analyse_module: types.ModuleType, + jieba_attrs: dict[str, object] | None = None, + tfidf_module: types.ModuleType | None = None, +): + jieba_module = types.ModuleType("jieba") + jieba_module.__path__ = [] + if jieba_attrs: + for key, value in jieba_attrs.items(): + setattr(jieba_module, key, value) + + jieba_module.analyse = analyse_module + analyse_module.__package__ = "jieba" + + monkeypatch.setitem(sys.modules, "jieba", jieba_module) + monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module) + if tfidf_module is not None: + monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module) + else: + monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False) + + +def test_init_uses_existing_default_tfidf(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + default_tfidf = _DummyTFIDF() + analyse_module.default_tfidf = default_tfidf + + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert handler._tfidf is default_tfidf + assert handler._tfidf.stop_words == STOPWORDS + + +def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + class _TFIDFFactory(_DummyTFIDF): + pass + + analyse_module.TFIDF = _TFIDFFactory + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _TFIDFFactory) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + tfidf_module = types.ModuleType("jieba.analyse.tfidf") + + class _ImportedTFIDF(_DummyTFIDF): + pass + + tfidf_module.TFIDF = _ImportedTFIDF + _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _ImportedTFIDF) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1) + + assert fallback_keywords == ["two"] + + +def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]}) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["x"] + + +def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules( + monkeypatch, + analyse_module, + jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])}, + ) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["foo"] + + +def test_extract_keywords_expands_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"]) + + keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3) + + assert "alpha-beta" in keywords + assert "alpha" in keywords + assert "beta" in keywords + assert "during" in keywords + assert "gamma" in keywords + + +def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + + expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"}) + + assert "alpha-during-beta" in expanded + assert "alpha" in expanded + assert "beta" in expanded + assert "during" not in expanded diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py new file mode 100644 index 0000000000..1b1541ddd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py @@ -0,0 +1,6 @@ +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +def test_stopwords_loaded(): + assert "during" in STOPWORDS + assert "the" in STOPWORDS diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py new file mode 100644 index 0000000000..55e22aea0a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document + + +class _KeywordThatRaises(BaseKeyword): + def create(self, texts: list[Document], **kwargs): + return super().create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + return super().add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return super().text_exists(id) + + def delete_by_ids(self, ids: list[str]): + return super().delete_by_ids(ids) + + def delete(self): + return super().delete() + + def search(self, query: str, **kwargs): + return super().search(query, **kwargs) + + +class _KeywordForHelpers(BaseKeyword): + def __init__(self, dataset, existing_ids: set[str] | None = None): + super().__init__(dataset) + self._existing_ids = existing_ids or set() + + def create(self, texts: list[Document], **kwargs): + return self + + def add_texts(self, texts: list[Document], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete(self): + return None + + def search(self, query: str, **kwargs): + return [] + + +def test_abstract_methods_raise_not_implemented(): + keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1")) + + with pytest.raises(NotImplementedError): + keyword.create([]) + + with pytest.raises(NotImplementedError): + keyword.add_texts([]) + + with pytest.raises(NotImplementedError): + keyword.text_exists("doc-1") + + with pytest.raises(NotImplementedError): + keyword.delete_by_ids(["doc-1"]) + + with pytest.raises(NotImplementedError): + keyword.delete() + + with pytest.raises(NotImplementedError): + keyword.search("query") + + +def test_filter_duplicate_texts_removes_existing_doc_ids(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"}) + texts = [ + Document(page_content="keep", metadata={"doc_id": "keep"}), + Document(page_content="duplicate", metadata={"doc_id": "duplicate"}), + SimpleNamespace(page_content="without-metadata", metadata=None), + ] + + filtered = keyword._filter_duplicate_texts(texts) + + assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"] + assert any(text.metadata is None for text in filtered) + + +def test_get_uuids_returns_only_docs_with_metadata(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1")) + texts = [ + Document(page_content="doc-1", metadata={"doc_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "doc-2"}), + SimpleNamespace(page_content="doc-3", metadata=None), + ] + + assert keyword._get_uuids(texts) == ["doc-1", "doc-2"] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py new file mode 100644 index 0000000000..0d969a3270 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py @@ -0,0 +1,84 @@ +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.keyword.keyword_type import KeyWordType +from core.rag.models.document import Document + + +def test_get_keyword_factory_returns_jieba_factory(monkeypatch): + fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba") + + class FakeJieba: + pass + + fake_module.Jieba = FakeJieba + monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module) + + assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba + + +def test_get_keyword_factory_raises_for_unsupported_type(): + with pytest.raises(ValueError, match="Keyword store unsupported is not supported"): + Keyword.get_keyword_factory("unsupported") + + +def test_keyword_initialization_uses_configured_factory(monkeypatch): + dataset = SimpleNamespace(id="dataset-1") + fake_processor = MagicMock() + + monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA) + monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor)) + + keyword = Keyword(dataset) + + assert keyword._keyword_processor is fake_processor + + +def test_keyword_methods_forward_to_processor(): + processor = MagicMock() + processor.text_exists.return_value = True + processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})] + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = processor + + docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})] + keyword.create(docs, foo="bar") + keyword.add_texts(docs, batch=True) + assert keyword.text_exists("doc-1") is True + keyword.delete_by_ids(["doc-1"]) + keyword.delete() + assert keyword.search("query", top_k=1) == processor.search.return_value + + processor.create.assert_called_once_with(docs, foo="bar") + processor.add_texts.assert_called_once_with(docs, batch=True) + processor.text_exists.assert_called_once_with("doc-1") + processor.delete_by_ids.assert_called_once_with(["doc-1"]) + processor.delete.assert_called_once() + processor.search.assert_called_once_with("query", top_k=1) + + +def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes(): + class Processor: + value = 1 + + @staticmethod + def custom(): + return "ok" + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = Processor() + + assert keyword.custom() == "ok" + + with pytest.raises(AttributeError): + _ = keyword.value + + keyword._keyword_processor = None + with pytest.raises(AttributeError): + _ = keyword.missing_method diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py new file mode 100644 index 0000000000..63de4b8af2 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -0,0 +1,1176 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, call, patch +from uuid import uuid4 + +import pytest + +from core.rag.datasource import retrieval_service as retrieval_service_module +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset + + +def create_mock_document( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + This helper function standardizes document creation across tests, + ensuring consistent structure and reducing code duplication. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + + Example: + >>> doc = create_mock_document("Python is great", "doc1", score=0.95) + >>> assert doc.metadata["score"] == 0.95 + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + # Merge additional metadata if provided + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +class _ImmediateFuture: + def __init__(self, exception: Exception | None = None) -> None: + self._exception = exception + self.cancel_called = False + + def exception(self) -> Exception | None: + return self._exception + + def cancel(self) -> None: + self.cancel_called = True + + +class _ImmediateExecutor: + def __init__(self) -> None: + self.futures: list[_ImmediateFuture] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + try: + fn(*args, **kwargs) + future = _ImmediateFuture() + except Exception as exc: # pragma: no cover - only for defensive parity with Future semantics + future = _ImmediateFuture(exc) + self.futures.append(future) + return future + + +class _FakeExecuteScalarResult: + def __init__(self, data: list) -> None: + self._data = data + + def all(self) -> list: + return self._data + + +class _FakeExecuteResult: + def __init__(self, data: list) -> None: + self._data = data + + def scalars(self) -> _FakeExecuteScalarResult: + return _FakeExecuteScalarResult(self._data) + + +class _FakeSummaryQuery: + def __init__(self, summaries: list) -> None: + self._summaries = summaries + + def filter(self, *args, **kwargs): + return self + + def all(self) -> list: + return self._summaries + + +class _FakeSession: + def __init__(self, execute_payloads: list[list], summaries: list) -> None: + self._payloads = list(execute_payloads) + self._summaries = summaries + + def execute(self, stmt): + data = self._payloads.pop(0) if self._payloads else [] + return _FakeExecuteResult(data) + + def query(self, model): + return _FakeSummaryQuery(self._summaries) + + +class _FakeSessionContext: + def __init__(self, session: _FakeSession) -> None: + self._session = session + + def __enter__(self) -> _FakeSession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class _SimpleRetrievalChildChunk: + def __init__(self, id: str, content: str, score: float, position: int) -> None: + self.id = id + self.content = content + self.score = score + self.position = position + + +class _SimpleRetrievalSegment: + def __init__( + self, + segment, + child_chunks: list[_SimpleRetrievalChildChunk] | None = None, + score: float | None = None, + files: list[dict[str, str | int]] | None = None, + summary: str | None = None, + ) -> None: + self.segment = segment + self.child_chunks = child_chunks + self.score = score + self.files = files + self.summary = summary + + +class TestRetrievalServiceInternals: + @pytest.fixture + def internal_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = "dataset-id" + dataset.tenant_id = "tenant-id" + dataset.is_multimodal = False + dataset.doc_form = IndexStructureType.PARENT_CHILD_INDEX + return dataset + + @pytest.fixture + def internal_flask_app(self): + app = MagicMock() + app.app_context.return_value.__enter__ = Mock() + app.app_context.return_value.__exit__.return_value = False + return app + + def test_retrieve_with_attachment_ids_only(self, monkeypatch, internal_dataset): + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset", return_value=internal_dataset), + patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") as mock_retrieve, + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + def side_effect( + flask_app, + retrieval_method, + dataset, + all_documents, + exceptions, + query=None, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + ): + all_documents.append(create_mock_document(f"content-{attachment_id}", attachment_id or "none", 0.9)) + + mock_retrieve.side_effect = side_effect + + results = RetrievalService.retrieve( + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset_id=internal_dataset.id, + query="", + attachment_ids=["att-1", "att-2"], + ) + + assert len(results) == 2 + assert {doc.metadata["doc_id"] for doc in results} == {"att-1", "att-2"} + assert mock_retrieve.call_count == 2 + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + mock_validate.return_value = "validated-condition" + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, + ) + + assert results == expected_documents + mock_validate.assert_called_once() + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition="validated-condition", + ) + + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_returns_empty_when_dataset_not_found(self, mock_scalar): + mock_scalar.return_value = None + + results = RetrievalService.external_retrieve(dataset_id="missing", query="q") + + assert results == [] + + @patch("core.rag.datasource.retrieval_service.Session") + def test_get_dataset_queries_by_id(self, mock_session_class): + expected_dataset = Mock(spec=Dataset) + mock_session = Mock() + mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session_class.return_value.__enter__.return_value = mock_session + + with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): + result = RetrievalService._get_dataset("dataset-123") + + assert result == expected_dataset + mock_session.query.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_success(self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.return_value = [create_mock_document("keyword-content", "kw-1", 0.91)] + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "with quotes"', + top_k=5, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + keyword_instance.search.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_dataset_missing(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.side_effect = RuntimeError("keyword failed") + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["keyword failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [create_mock_document("vector-content", "vec-1", 0.7)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + document_ids_filter=["doc-1"], + query_type=QueryType.TEXT_QUERY, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_vector.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_non_multimodal_returns_early( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-1", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == [] + assert exceptions == [] + vector_instance.search_by_file.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_with_vision_reranking( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + reranked_docs = [create_mock_document("image-content-reranked", "img-doc", 0.97)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = True + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + model_manager.check_model_support_vision.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_without_vision_support( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = [create_mock_document("unused", "unused", 0.1)] + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = False + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == original_docs + assert exceptions == [] + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + processor_instance.invoke.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_with_reranking_non_multimodal( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("vector-content", "vec-doc", 0.62)] + reranked_docs = [create_mock_document("vector-content-reranked", "vec-doc", 0.89)] + + vector_instance = Mock() + vector_instance.search_by_vector.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_appends_exception_when_vector_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.side_effect = RuntimeError("vector failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == [] + assert exceptions == ["vector failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = [create_mock_document("fulltext", "ft-1", 0.68)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "x"', + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_full_text.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_with_reranking( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("fulltext", "ft-1", 0.68)] + reranked_docs = [create_mock_document("fulltext-reranked", "ft-1", 0.9)] + + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_dataset_not_found(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.side_effect = RuntimeError("fulltext failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["fulltext failed"] + + def test_format_retrieval_documents_with_empty_input_returns_empty_list(self): + assert RetrievalService.format_retrieval_documents([]) == [] + + def test_format_retrieval_documents_without_document_id_returns_empty_list(self): + documents = [Document(page_content="content", metadata={"doc_id": "doc-1", "score": 0.4}, provider="dify")] + + assert RetrievalService.format_retrieval_documents(documents) == [] + + def test_format_retrieval_documents_with_parent_child_summary_and_attachments(self, monkeypatch): + dataset_doc_parent = SimpleNamespace( + id="doc-parent", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + dataset_doc_text = SimpleNamespace(id="doc-text", doc_form="paragraph", dataset_id="dataset-id") + dataset_doc_parent_summary = SimpleNamespace( + id="doc-parent-summary", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + + dataset_query = Mock() + dataset_query.where.return_value.options.return_value.all.return_value = [ + dataset_doc_parent, + dataset_doc_text, + dataset_doc_parent_summary, + ] + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) + monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) + monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) + + input_documents = [ + Document( + page_content="child node content", + metadata={"document_id": "doc-parent", "doc_id": "child-node-1", "score": 0.7}, + provider="dify", + ), + Document( + page_content="parent image", + metadata={ + "document_id": "doc-parent", + "doc_id": "attach-node-1", + "doc_type": DocType.IMAGE, + "score": 0.8, + }, + provider="dify", + ), + Document( + page_content="text index node", + metadata={"document_id": "doc-text", "doc_id": "index-node-1", "score": 0.6}, + provider="dify", + ), + Document( + page_content="text image node", + metadata={ + "document_id": "doc-text", + "doc_id": "attach-text-1", + "doc_type": DocType.IMAGE, + "score": 0.65, + }, + provider="dify", + ), + Document( + page_content="summary candidate 1", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-1", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.9", + }, + provider="dify", + ), + Document( + page_content="summary candidate 2", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-2", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.95", + }, + provider="dify", + ), + Document( + page_content="invalid score summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-invalid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "invalid", + }, + provider="dify", + ), + Document( + page_content="valid parent summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-valid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "0.4", + }, + provider="dify", + ), + ] + + child_chunk = SimpleNamespace( + id="child-chunk-1", + segment_id="segment-parent", + index_node_id="child-node-1", + content="child details", + position=2, + ) + segment_parent = SimpleNamespace(id="segment-parent", document_id="doc-parent", index_node_id="parent-node") + segment_text = SimpleNamespace(id="segment-text", document_id="doc-text", index_node_id="index-node-1") + segment_summary = SimpleNamespace(id="segment-summary", document_id="doc-text", index_node_id="summary-node") + segment_parent_summary = SimpleNamespace( + id="segment-parent-summary", + document_id="doc-parent-summary", + index_node_id="summary-parent-node", + ) + + fake_session = _FakeSession( + execute_payloads=[ + [child_chunk], + [segment_text], + [segment_parent, segment_text], + [segment_summary, segment_parent_summary], + ], + summaries=[ + SimpleNamespace(chunk_id="segment-summary", summary_content="summary for text"), + SimpleNamespace(chunk_id="segment-parent-summary", summary_content="summary for parent"), + ], + ) + monkeypatch.setattr( + retrieval_service_module.session_factory, + "create_session", + lambda: _FakeSessionContext(fake_session), + ) + monkeypatch.setattr( + RetrievalService, + "get_segment_attachment_infos", + lambda attachment_ids, session: [ + { + "attachment_id": "attach-node-1", + "attachment_info": { + "id": "attach-node-1", + "name": "img-parent", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://parent", + "size": 11, + }, + "segment_id": "segment-parent", + }, + { + "attachment_id": "attach-text-1", + "attachment_info": { + "id": "attach-text-1", + "name": "img-text", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://text", + "size": 22, + }, + "segment_id": "segment-text", + }, + ], + ) + + result = RetrievalService.format_retrieval_documents(input_documents) + + assert len(result) == 4 + result_by_segment_id = {item.segment.id: item for item in result} + assert result_by_segment_id["segment-summary"].score == pytest.approx(0.95) + assert result_by_segment_id["segment-summary"].summary == "summary for text" + assert result_by_segment_id["segment-parent"].score == pytest.approx(0.8) + assert result_by_segment_id["segment-parent"].files is not None + assert len(result_by_segment_id["segment-parent"].child_chunks or []) == 1 + assert result_by_segment_id["segment-text"].score == pytest.approx(0.65) + assert result_by_segment_id["segment-parent-summary"].score == pytest.approx(0.4) + assert result_by_segment_id["segment-parent-summary"].summary == "summary for parent" + assert result_by_segment_id["segment-parent-summary"].child_chunks == [] + + def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): + rollback = Mock() + monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) + + documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] + + with pytest.raises(RuntimeError, match="db error"): + RetrievalService.format_retrieval_documents(documents) + + rollback.assert_called_once() + + def test_retrieve_internal_returns_early_without_query_or_attachment(self, internal_dataset, internal_flask_app): + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=exceptions, + query=None, + attachment_id=None, + ) + + assert all_documents == [] + assert exceptions == [] + + def test_retrieve_internal_cancels_futures_when_future_has_exception(self, internal_dataset, internal_flask_app): + future_error = Mock() + future_error.exception.return_value = RuntimeError("future failed") + future_ok = Mock() + future_ok.exception.return_value = None + + with ( + patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor, + patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + return_value=[future_error, future_ok], + ), + ): + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = [future_error, future_ok] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=[], + query="query", + attachment_id="file-1", + ) + + future_error.cancel.assert_called() + future_ok.cancel.assert_called() + + def test_retrieve_internal_raises_value_error_when_exceptions_exist( + self, monkeypatch, internal_dataset, internal_flask_app + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + with patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") as mock_keyword_search: + mock_keyword_search.side_effect = lambda *args, **kwargs: None + with pytest.raises(ValueError, match="keyword error"): + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.KEYWORD_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=["keyword error"], + query="query", + ) + + def test_retrieve_internal_hybrid_weighted_attachment_flow(self, monkeypatch, internal_dataset, internal_flask_app): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + text_doc = create_mock_document("text", "text-doc", 0.81) + image_doc = create_mock_document("image", "image-doc", 0.72) + fulltext_doc = create_mock_document("full", "full-doc", 0.65) + processed_doc = create_mock_document("processed", "processed-doc", 0.99) + + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") as mock_embedding_search, + patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search") as mock_fulltext, + patch("core.rag.datasource.retrieval_service.DataPostProcessor") as mock_processor_class, + ): + + def embedding_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + query_type=QueryType.TEXT_QUERY, + ): + if query_type == QueryType.IMAGE_QUERY: + all_documents.append(image_doc) + else: + all_documents.append(text_doc) + + mock_embedding_search.side_effect = embedding_side_effect + + def fulltext_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + ): + all_documents.append(fulltext_doc) + + mock_fulltext.side_effect = fulltext_side_effect + processor_instance = Mock() + processor_instance.invoke.return_value = [processed_doc] + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.HYBRID_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=[], + query="query", + attachment_id="file-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + top_k=3, + ) + + assert len(all_documents) == 4 + assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_info_success(self, mock_sign): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = binding + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result == { + "attachment_info": { + "id": "upload-1", + "name": "file-name", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + mock_sign.assert_called_once_with("upload-1", "png") + + def test_get_segment_attachment_info_returns_none_when_binding_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = None + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): + upload_query = Mock() + upload_query.where.return_value.first.return_value = None + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): + upload_query = Mock() + upload_query.where.return_value.all.return_value = [] + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + def test_get_segment_attachment_infos_returns_empty_when_bindings_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_infos_success(self, mock_sign): + upload_file_1 = SimpleNamespace( + id="upload-1", + name="file-1", + extension="png", + mime_type="image/png", + size=42, + ) + upload_file_2 = SimpleNamespace( + id="upload-2", + name="file-2", + extension="jpg", + mime_type="image/jpeg", + size=99, + ) + binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") + + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [binding] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) + + assert result == [ + { + "attachment_id": "upload-1", + "attachment_info": { + "id": "upload-1", + "name": "file-1", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + ] + mock_sign.assert_has_calls( + [ + call("upload-1", "png"), + call("upload-2", "jpg"), + ] + ) + assert mock_sign.call_count == 2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py new file mode 100644 index 0000000000..e063a49f22 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py @@ -0,0 +1,74 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory + + +def test_validate_distance_function_accepts_supported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + assert factory._validate_distance_function("cosine") == "cosine" + assert factory._validate_distance_function("euclidean") == "euclidean" + + +def test_validate_distance_function_rejects_unsupported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + with pytest.raises(ValueError, match="Invalid distance function"): + factory._validate_distance_function("dot_product") + + +def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" + + +def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", + index_struct_dict=None, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + vector_cls.assert_called_once() + assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2" + assert dataset.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py new file mode 100644 index 0000000000..545565cdf4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from core.rag.models.document import Document + + +def test_init_prefers_openapi_when_api_config_is_provided(): + api_config = AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls: + vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None) + + assert vector.analyticdb_vector == "openapi_runner" + openapi_cls.assert_called_once_with("COLLECTION", api_config) + + +def test_init_uses_sql_implementation_when_api_config_is_missing(): + sql_config = AnalyticdbVectorBySqlConfig( + host="localhost", + port=5432, + account="account", + account_password="password", + min_connection=1, + max_connection=2, + namespace="dify", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls: + vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config) + + assert vector.analyticdb_vector == "sql_runner" + sql_cls.assert_called_once_with("COLLECTION", sql_config) + + +def test_init_raises_when_both_configs_are_missing(): + with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"): + AnalyticdbVector("COLLECTION", api_config=None, sql_config=None) + + +def test_vector_methods_delegate_to_underlying_implementation(): + runner = MagicMock() + runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})] + runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})] + runner.text_exists.return_value = True + + vector = AnalyticdbVector.__new__(AnalyticdbVector) + vector.analyticdb_vector = runner + + texts = [Document(page_content="hello", metadata={"doc_id": "d1"})] + vector.create(texts=texts, embeddings=[[0.1, 0.2]]) + vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]]) + assert vector.text_exists("d1") is True + vector.delete_by_ids(["d1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value + assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value + vector.delete() + + runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) + runner.delete_by_ids.assert_called_once_with(["d1"]) + runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") + runner.delete.assert_called_once() + + +def test_get_type_is_analyticdb(): + vector = AnalyticdbVector.__new__(AnalyticdbVector) + assert vector.get_type() == "analyticdb" + + +def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "auto_collection" + assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig) + assert args[2] is None + assert dataset.index_struct is not None + + +def test_factory_builds_sql_config_when_host_is_present(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "existing" + assert args[1] is None + assert isinstance(args[2], AnalyticdbVectorBySqlConfig) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py new file mode 100644 index 0000000000..45777774d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -0,0 +1,384 @@ +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.models.document import Document + + +def _request_class(name: str): + class _Request: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + _Request.__name__ = name + return _Request + + +def _install_openapi_stubs(monkeypatch): + gpdb_package = types.ModuleType("alibabacloud_gpdb20160503") + gpdb_package.__path__ = [] + gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models") + for class_name in [ + "InitVectorDatabaseRequest", + "DescribeNamespaceRequest", + "CreateNamespaceRequest", + "DescribeCollectionRequest", + "CreateCollectionRequest", + "UpsertCollectionDataRequestRows", + "UpsertCollectionDataRequest", + "QueryCollectionDataRequest", + "DeleteCollectionDataRequest", + "DeleteCollectionRequest", + ]: + setattr(gpdb_models, class_name, _request_class(class_name)) + + class _Client: + def __init__(self, config): + self.config = config + + gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client") + gpdb_client.Client = _Client + gpdb_package.models = gpdb_models + + tea_openapi = types.ModuleType("alibabacloud_tea_openapi") + tea_openapi.__path__ = [] + tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models") + + class OpenApiConfig: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + tea_openapi_models.Config = OpenApiConfig + tea_openapi.models = tea_openapi_models + + tea_package = types.ModuleType("Tea") + tea_package.__path__ = [] + tea_exceptions = types.ModuleType("Tea.exceptions") + + class TeaError(Exception): + def __init__(self, status_code=None, **kwargs): + super().__init__("TeaException") + status_code = kwargs.get("statusCode", status_code) + self.statusCode = status_code + self.status_code = status_code + + tea_exceptions.TeaException = TeaError + tea_package.exceptions = tea_exceptions + + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models) + monkeypatch.setitem(sys.modules, "Tea", tea_package) + monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions) + + return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig) + + +def _config() -> AnalyticdbVectorOpenAPIConfig: + return AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("access_key_id", "", "ANALYTICDB_KEY_ID"), + ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"), + ("region_id", "", "ANALYTICDB_REGION_ID"), + ("instance_id", "", "ANALYTICDB_INSTANCE_ID"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"), + ], +) +def test_openapi_config_validation(field, value, error_message): + values = _config().model_dump() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorOpenAPIConfig.model_validate(values) + + +def test_openapi_config_to_client_params(): + config = _config() + params = config.to_analyticdb_client_params() + + assert params["access_key_id"] == "ak" + assert params["access_key_secret"] == "sk" + assert params["region_id"] == "cn-hangzhou" + assert params["read_timeout"] == 60000 + + +def test_init_creates_openapi_client_and_runs_initialize(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + initialize_mock = MagicMock() + monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock) + + vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config()) + + assert vector._collection_name == "collection_1" + assert isinstance(vector._client_config, stubs.OpenApiConfig) + assert vector._client_config.user_agent == "dify" + assert vector._client_config.access_key_id == "ak" + assert vector._client.config is vector._client_config + initialize_mock.assert_called_once_with() + + +def test_initialize_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + vector._create_namespace_if_not_exists.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + vector._create_namespace_if_not_exists.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_initialize_vector_database_calls_openapi_client(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._initialize_vector_database() + + request = vector._client.init_vector_database.call_args.args[0] + assert request.dbinstance_id == "instance-1" + assert request.region_id == "cn-hangzhou" + assert request.manager_account == "account" + assert request.manager_account_password == "password" + + +def test_create_namespace_creates_when_namespace_not_found(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404) + + vector._create_namespace_if_not_exists() + + vector._client.create_namespace.assert_called_once() + + +def test_create_namespace_raises_on_unexpected_api_error(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create namespace"): + vector._create_namespace_if_not_exists() + + +def test_create_namespace_noop_when_namespace_exists(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._create_namespace_if_not_exists() + + vector._client.describe_namespace.assert_called_once() + vector._client.create_namespace.assert_not_called() + + +def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.create_collection.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.describe_collection.assert_not_called() + vector._client.create_collection.assert_not_called() + + +def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create collection collection_1"): + vector._create_collection_if_not_exists(embedding_dimension=512) + + +def test_openapi_add_delete_and_search_methods(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + documents = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + embeddings = [[0.1, 0.2], [0.2, 0.3]] + vector.add_texts(documents, embeddings) + + upsert_request = vector._client.upsert_collection_data.call_args.args[0] + assert upsert_request.collection == "collection_1" + assert len(upsert_request.rows) == 1 + + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()])) + ) + assert vector.text_exists("d1") is True + + vector.delete_by_ids(["d1", "d2"]) + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')" + + vector.delete_by_metadata_field("document_id", "doc-1") + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'" + + match_high = SimpleNamespace( + score=0.9, + metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"}, + values=SimpleNamespace(value=[1.0, 2.0]), + ) + match_low = SimpleNamespace( + score=0.1, + metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"}, + values=SimpleNamespace(value=[3.0, 4.0]), + ) + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high])) + ) + + docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + assert len(docs_by_vector) == 1 + assert docs_by_vector[0].page_content == "high" + assert docs_by_vector[0].metadata["score"] == 0.9 + + docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2) + assert len(docs_by_text) == 1 + assert docs_by_text[0].page_content == "high" + + +def test_text_exists_returns_false_when_matches_empty(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[])) + ) + + assert vector.text_exists("missing-id") is False + + +def test_openapi_delete_success(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector.delete() + vector._client.delete_collection.assert_called_once() + + +def test_openapi_delete_propagates_errors(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.delete_collection.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.delete() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py new file mode 100644 index 0000000000..8f1206696b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -0,0 +1,427 @@ +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import psycopg2.errors +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( + AnalyticdbVectorBySql, + AnalyticdbVectorBySqlConfig, +) +from core.rag.models.document import Document + + +def _config_values() -> dict: + return { + "host": "localhost", + "port": 5432, + "account": "account", + "account_password": "password", + "min_connection": 1, + "max_connection": 2, + "namespace": "dify", + } + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("host", "", "ANALYTICDB_HOST"), + ("port", 0, "ANALYTICDB_PORT"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"), + ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"), + ], +) +def test_sql_config_required_fields(field, value, error_message): + values = _config_values() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_sql_config_rejects_min_connection_greater_than_max_connection(): + values = _config_values() + values["min_connection"] = 10 + values["max_connection"] = 2 + + with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_initialize_skips_when_cache_exists(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + sql_module.redis_client.set.assert_called_once() + + +def test_create_connection_pool_uses_psycopg2_pool(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + pool_instance = MagicMock() + monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance)) + + pool = vector._create_connection_pool() + + assert pool is pool_instance + sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once() + + +def test_get_cursor_context_manager_handles_connection_lifecycle(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + cursor = MagicMock() + connection = MagicMock() + connection.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = connection + vector.pool = pool + + with vector._get_cursor() as cur: + assert cur is cursor + + cursor.close.assert_called_once() + connection.commit.assert_called_once() + pool.putconn.assert_called_once_with(connection) + + +def test_add_texts_inserts_only_documents_with_metadata(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id") + monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock()) + + docs = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]]) + + execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args + assert execute_args[0] is cursor + assert len(execute_args[2]) == 1 + + +def test_text_exists_returns_true_and_false_based_on_query_result(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.fetchone.return_value = ("row",) + assert vector.text_exists("d1") is True + + cursor.fetchone.return_value = None + assert vector.text_exists("d1") is False + + +def test_delete_by_ids_handles_empty_input_and_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_ids(["d1"]) + + +def test_delete_by_metadata_field_handles_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_metadata_field("document_id", "doc-1") + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_vector_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k) + + +def test_search_by_vector_returns_documents_above_threshold(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}), + ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content 1" + assert docs[0].metadata["score"] == 0.8 + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_full_text_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=invalid_top_k) + + +def test_search_by_full_text_returns_documents(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + assert docs[0].page_content == "content 1" + + +def test_delete_drops_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete() + + cursor.execute.assert_called_once() + + +def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch): + config = AnalyticdbVectorBySqlConfig(**_config_values()) + created_pool = MagicMock() + + monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock()) + monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool)) + + vector = AnalyticdbVectorBySql("My_Collection", config) + + assert vector._collection_name == "my_collection" + assert vector.table_name == "dify.my_collection" + assert vector.databaseName == "knowledgebase" + assert vector.pool is created_pool + + +def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + bootstrap_cursor.execute.side_effect = RuntimeError("database already exists") + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + + def _execute(sql, *args, **kwargs): + if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql: + raise RuntimeError("already exists") + + worker_cursor.execute.side_effect = _execute + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + vector._initialize_vector_database() + + bootstrap_cursor.close.assert_called_once() + bootstrap_connection.close.assert_called_once() + vector._create_connection_pool.assert_called_once() + assert any( + "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0] + for call in worker_cursor.execute.call_args_list + ) + assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list) + + +def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable") + + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + with pytest.raises(RuntimeError, match="Failed to create zhparser extension"): + vector._initialize_vector_database() + + worker_connection.rollback.assert_called_once() + + +def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + vector._create_collection_if_not_exists(embedding_dimension=3) + + assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) + assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) + sql_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + cursor.execute.side_effect = RuntimeError("permission denied") + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + with pytest.raises(RuntimeError, match="permission denied"): + vector._create_collection_if_not_exists(embedding_dimension=3) + + +def test_delete_methods_raise_when_error_is_not_missing_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.execute.side_effect = RuntimeError("unexpected delete failure") + with pytest.raises(RuntimeError, match="unexpected delete failure"): + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("unexpected metadata failure") + with pytest.raises(RuntimeError, match="unexpected metadata failure"): + vector.delete_by_metadata_field("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py new file mode 100644 index 0000000000..c46c3d5e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -0,0 +1,542 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_pymochow_modules(): + pymochow = types.ModuleType("pymochow") + pymochow.__path__ = [] + pymochow_auth = types.ModuleType("pymochow.auth") + pymochow_auth.__path__ = [] + pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials") + pymochow_configuration = types.ModuleType("pymochow.configuration") + pymochow_exception = types.ModuleType("pymochow.exception") + pymochow_model = types.ModuleType("pymochow.model") + pymochow_model.__path__ = [] + pymochow_model_database = types.ModuleType("pymochow.model.database") + pymochow_model_enum = types.ModuleType("pymochow.model.enum") + pymochow_model_schema = types.ModuleType("pymochow.model.schema") + pymochow_model_table = types.ModuleType("pymochow.model.table") + + class _SimpleObject: + def __init__(self, *args, **kwargs): + self.args = args + for key, value in kwargs.items(): + setattr(self, key, value) + + class ServerError(Exception): + def __init__(self, code): + super().__init__(f"server error {code}") + self.code = code + + class ServerErrCode: + TABLE_NOT_EXIST = 1001 + DB_ALREADY_EXIST = 1002 + + class IndexType: + __members__ = {"HNSW": "HNSW"} + + class MetricType: + __members__ = {"IP": "IP"} + + class IndexState: + NORMAL = "NORMAL" + + class TableState: + NORMAL = "NORMAL" + + class InvertedIndexAnalyzer: + DEFAULT_ANALYZER = "DEFAULT_ANALYZER" + + class InvertedIndexParseMode: + COARSE_MODE = "COARSE_MODE" + + class InvertedIndexFieldAttribute: + ANALYZED = "ANALYZED" + + class FieldType: + STRING = "STRING" + TEXT = "TEXT" + JSON = "JSON" + FLOAT_VECTOR = "FLOAT_VECTOR" + + pymochow.MochowClient = _SimpleObject + pymochow_credentials.BceCredentials = _SimpleObject + pymochow_configuration.Configuration = _SimpleObject + pymochow_exception.ServerError = ServerError + pymochow_model_database.Database = _SimpleObject + + pymochow_model_enum.FieldType = FieldType + pymochow_model_enum.IndexState = IndexState + pymochow_model_enum.IndexType = IndexType + pymochow_model_enum.MetricType = MetricType + pymochow_model_enum.ServerErrCode = ServerErrCode + pymochow_model_enum.TableState = TableState + + for cls_name in [ + "AutoBuildRowCountIncrement", + "Field", + "FilteringIndex", + "HNSWParams", + "InvertedIndex", + "InvertedIndexParams", + "Schema", + "VectorIndex", + ]: + setattr(pymochow_model_schema, cls_name, _SimpleObject) + pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer + pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute + pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode + + for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]: + setattr(pymochow_model_table, cls_name, _SimpleObject) + + pymochow.auth = pymochow_auth + pymochow.model = pymochow_model + pymochow_auth.bce_credentials = pymochow_credentials + pymochow_model.database = pymochow_model_database + pymochow_model.enum = pymochow_model_enum + pymochow_model.schema = pymochow_model_schema + pymochow_model.table = pymochow_model_table + + modules = { + "pymochow": pymochow, + "pymochow.auth": pymochow_auth, + "pymochow.auth.bce_credentials": pymochow_credentials, + "pymochow.configuration": pymochow_configuration, + "pymochow.exception": pymochow_exception, + "pymochow.model": pymochow_model, + "pymochow.model.database": pymochow_model_database, + "pymochow.model.enum": pymochow_model_enum, + "pymochow.model.schema": pymochow_model_schema, + "pymochow.model.table": pymochow_model_table, + } + return modules + + +@pytest.fixture +def baidu_module(monkeypatch): + for name, module in _build_fake_pymochow_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + import core.rag.datasource.vdb.baidu.baidu_vector as module + + return importlib.reload(module) + + +def test_baidu_config_validation(baidu_module): + values = { + "endpoint": "https://example.com", + "account": "account", + "api_key": "key", + "database": "database", + } + config = baidu_module.BaiduConfig.model_validate(values) + assert config.endpoint == "https://example.com" + + for key, error_message in [ + ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"), + ("account", "BAIDU_VECTOR_DB_ACCOUNT"), + ("api_key", "BAIDU_VECTOR_DB_API_KEY"), + ("database", "BAIDU_VECTOR_DB_DATABASE"), + ]: + invalid = dict(values) + invalid[key] = "" + with pytest.raises(ValueError, match=error_message): + baidu_module.BaiduConfig.model_validate(invalid) + + +def test_get_search_result_handles_metadata_and_threshold(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace( + rows=[ + {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9}, + {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4}, + {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95}, + ] + ) + + docs = vector._get_search_res(response, score_threshold=0.8) + + assert len(docs) == 2 + assert docs[0].page_content == "doc1" + assert docs[0].metadata["score"] == 0.9 + assert docs[1].page_content == "doc3" + + +def test_delete_by_ids_and_delete_by_metadata_field(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + vector._collection_name = "collection_1" + + vector.delete_by_ids([]) + table.delete.assert_not_called() + + vector.delete_by_ids(["id1", "id2"]) + table.delete.assert_called_once() + + table.delete.reset_mock() + vector.delete_by_metadata_field("source", 'abc"def') + delete_filter = table.delete.call_args.kwargs["filter"] + assert delete_filter == 'metadata["source"] = "abc\\"def"' + + +def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + + vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + vector.delete() + + vector._db.drop_table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector.delete() + + +def test_init_database_uses_existing_or_creates_when_missing(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + + vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")] + vector._client.database.return_value = "existing_db" + assert vector._init_database() == "existing_db" + + vector._client.list_databases.return_value = [] + vector._client.database.return_value = "created_db" + vector._client.create_database.side_effect = None + assert vector._init_database() == "created_db" + + vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST) + assert vector._init_database() == "created_db" + + +def test_table_existed_checks_table_access(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._db.table.return_value = MagicMock() + + assert vector._table_existed() is True + + vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + assert vector._table_existed() is False + + vector._db.table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector._table_existed() + + +def test_search_methods_delegate_to_database_table(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})]) + + table = MagicMock() + vector._db.table.return_value = table + table.search.return_value = "vector_result" + table.bm25_search.return_value = "bm25_result" + + result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + + assert result1 == vector._get_search_res.return_value + assert result2 == vector._get_search_res.return_value + assert vector._get_search_res.call_count == 2 + + +def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection" + assert dataset.index_struct is not None + + +def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch): + init_client = MagicMock(return_value="client") + init_database = MagicMock(return_value="database") + monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client) + monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database) + + config = baidu_module.BaiduConfig( + endpoint="https://example.com", + account="account", + api_key="key", + database="db", + ) + vector = baidu_module.BaiduVector(collection_name="my_collection", config=config) + + assert vector.get_type() == baidu_module.VectorType.BAIDU + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection" + assert vector._client == "client" + assert vector._db == "database" + + vector._create_table = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="p1", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_table.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_rows(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]]) + + assert table.upsert.call_count == 1 + inserted_rows = table.upsert.call_args.kwargs["rows"] + assert len(inserted_rows) == 2 + + +def test_add_texts_batches_more_than_batch_size(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"}) + for idx in range(1001) + ] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + vector.add_texts(docs, embeddings) + + assert table.upsert.call_count == 2 + assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000 + assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1 + + +def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + table.query.return_value = SimpleNamespace(code=0) + assert vector.text_exists("id-1") is True + + table.query.return_value = SimpleNamespace(code=1) + assert vector.text_exists("id-1") is False + + table.query.return_value = None + assert vector.text_exists("id-1") is False + + +def test_get_search_result_handles_invalid_metadata_json(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]) + + docs = vector._get_search_res(response, score_threshold=0.1) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.7 + assert "document_id" not in docs[0].metadata + + +def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch): + credentials = MagicMock(return_value="credentials") + configuration = MagicMock(return_value="configuration") + client_cls = MagicMock(return_value="client") + monkeypatch.setattr(baidu_module, "BceCredentials", credentials) + monkeypatch.setattr(baidu_module, "Configuration", configuration) + monkeypatch.setattr(baidu_module, "MochowClient", client_cls) + + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + + client = vector._init_client(config) + + assert client == "client" + credentials.assert_called_once_with("account", "key") + configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + client_cls.assert_called_once_with("configuration") + + +def test_init_database_raises_for_unknown_create_database_error(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + vector._client.list_databases.return_value = [] + vector._client.create_database.side_effect = baidu_module.ServerError(9999) + + with pytest.raises(baidu_module.ServerError): + vector._init_database() + + +def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + table = MagicMock() + table.state = baidu_module.TableState.NORMAL + vector._db.describe_table.return_value = table + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock()) + + # Cached table skips all work. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Existing table also skips creation. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + vector._table_existed.return_value = True + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Create table when cache is empty and table does not exist. + vector._table_existed.return_value = False + vector._create_table(3) + vector._db.create_table.assert_called_once() + baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600) + table.rebuild_index.assert_called_once_with(vector.vector_index) + vector._wait_for_index_ready.assert_called_once_with(table, 3600) + + +def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + vector._client_config = SimpleNamespace( + index_type="INVALID", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_table(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "INVALID" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_table(3) + + +def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + vector._db.describe_table.return_value = SimpleNamespace(state="CREATING") + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301])) + + with pytest.raises(TimeoutError, match="Table creation timeout"): + vector._create_table(3) + + +def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py new file mode 100644 index 0000000000..44427b7d87 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py @@ -0,0 +1,199 @@ +import importlib +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_chroma_modules(): + chroma = types.ModuleType("chromadb") + chroma.DEFAULT_TENANT = "default_tenant" + chroma.DEFAULT_DATABASE = "default_database" + + class Settings: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class QueryResult(UserDict): + pass + + class _Collection: + def __init__(self): + self.upsert = MagicMock() + self.delete = MagicMock() + self.query = MagicMock() + self.get = MagicMock(return_value={}) + + class _Client: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.collection = _Collection() + self.get_or_create_collection = MagicMock(return_value=self.collection) + self.delete_collection = MagicMock() + + chroma.Settings = Settings + chroma.QueryResult = QueryResult + chroma.HttpClient = _Client + return chroma + + +@pytest.fixture +def chroma_module(monkeypatch): + fake_chroma = _build_fake_chroma_modules() + monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) + import core.rag.datasource.vdb.chroma.chroma_vector as module + + return importlib.reload(module) + + +def test_chroma_config_to_params_builds_expected_payload(chroma_module): + config = chroma_module.ChromaConfig( + host="localhost", + port=8000, + tenant="tenant-1", + database="db-1", + auth_provider="provider", + auth_credentials="credentials", + ) + + params = config.to_chroma_params() + + assert params["host"] == "localhost" + assert params["port"] == 8000 + assert params["tenant"] == "tenant-1" + assert params["database"] == "db-1" + assert params["ssl"] is False + assert params["settings"].chroma_client_auth_provider == "provider" + assert params["settings"].chroma_client_auth_credentials == "credentials" + + +def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock()) + + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create_collection("collection_1") + + vector._client.get_or_create_collection.assert_called_once_with("collection_1") + chroma_module.redis_client.set.assert_called_once() + + +def test_create_with_empty_texts_is_noop(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create([], []) + vector._client.get_or_create_collection.assert_not_called() + + +def test_create_with_texts_creates_collection_and_upserts(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._client.get_or_create_collection.assert_called() + vector._client.collection.upsert.assert_called_once() + + +def test_delete_methods_and_text_exists(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + + vector.delete_by_ids([]) + vector._client.collection.delete.assert_not_called() + + vector.delete_by_ids(["id-1"]) + vector._client.collection.delete.assert_called_with(ids=["id-1"]) + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}}) + + vector._client.collection.get.return_value = {"ids": ["id-1"]} + assert vector.text_exists("id-1") is True + vector._client.collection.get.return_value = {} + assert vector.text_exists("id-2") is False + + vector.delete() + vector._client.delete_collection.assert_called_once_with("collection_1") + + +def test_search_by_vector_handles_empty_results(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = {"ids": [], "documents": [], "metadatas": [], "distances": []} + + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + +def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = { + "ids": [["id-1", "id-2"]], + "documents": [["doc high", "doc low"]], + "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]], + "distances": [[0.1, 0.8]], + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc high" + assert docs[0].metadata["score"] == 0.9 + + +def test_search_by_full_text_returns_empty_list(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + assert vector.search_by_full_text("query") == [] + + +def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch): + factory = chroma_module.ChromaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None) + + with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py new file mode 100644 index 0000000000..0ce5c04dd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py @@ -0,0 +1,927 @@ +import importlib +import queue +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickzetta_module(): + clickzetta = types.ModuleType("clickzetta") + + class _FakeCursor: + def __init__(self): + self.execute = MagicMock() + self.executemany = MagicMock() + self.fetchall = MagicMock(return_value=[]) + self.fetchone = MagicMock(return_value=(0,)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _FakeConnection: + def __init__(self): + self.cursor_obj = _FakeCursor() + + def cursor(self): + return self.cursor_obj + + def close(self): + return None + + def connect(**_kwargs): + return _FakeConnection() + + clickzetta.connect = connect + return clickzetta + + +@pytest.fixture +def clickzetta_module(monkeypatch): + monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) + import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.ClickzettaConfig( + username="username", + password="password", + instance="instance", + service="service", + workspace="workspace", + vcluster="cluster", + schema_name="dify", + ) + + +@pytest.mark.parametrize( + ("field", "error_message"), + [ + ("username", "CLICKZETTA_USERNAME"), + ("password", "CLICKZETTA_PASSWORD"), + ("instance", "CLICKZETTA_INSTANCE"), + ("service", "CLICKZETTA_SERVICE"), + ("workspace", "CLICKZETTA_WORKSPACE"), + ("vcluster", "CLICKZETTA_VCLUSTER"), + ("schema_name", "CLICKZETTA_SCHEMA"), + ], +) +def test_clickzetta_config_validation(clickzetta_module, field, error_message): + values = _config(clickzetta_module).model_dump() + values[field] = "" + with pytest.raises(ValueError, match=error_message): + clickzetta_module.ClickzettaConfig.model_validate(values) + + +def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1") + assert parsed["doc_id"] == "row-1" + assert parsed["document_id"] == "doc-1" + + parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2") + assert parsed_double["doc_id"] == "row-2" + assert parsed_double["document_id"] == "doc-2" + + parsed_fallback = vector._parse_metadata("not-json", "row-3") + assert parsed_fallback["doc_id"] == "row-3" + assert parsed_fallback["document_id"] == "row-3" + + +def test_safe_doc_id_and_vector_format_helpers(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + assert vector._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3" + assert vector._safe_doc_id("abc-123_DEF") == "abc-123_DEF" + assert vector._safe_doc_id("ab c;\n") == "abc" + assert len(vector._safe_doc_id("a" * 300)) == 255 + + +def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_not_found(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_not_found + assert vector._table_exists() is False + + @contextmanager + def _ctx_other_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("permission denied") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_other_error + assert vector._table_exists() is False + + +def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.text_exists("doc-1") is False + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchone.return_value = (1,) + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + assert vector.text_exists("doc-1") is True + + +def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + + vector.delete_by_ids([]) + vector._execute_write.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + vector.delete_by_ids(["doc-1"]) + vector._execute_write.assert_not_called() + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_not_called() + + +def test_search_short_circuit_behaviors(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + vector._config.enable_inverted_index = False + assert vector.search_by_full_text("query", top_k=2) == [] + + +def test_search_by_like_returns_documents_with_default_score(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content" + assert docs[0].metadata["score"] == 0.5 + + +def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch): + factory = clickzetta_module.ClickzettaVectorFactory() + dataset = SimpleNamespace(id="dataset-1") + + monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance") + + with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "collection" + + +def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch): + clickzetta_module.ClickzettaConnectionPool._instance = None + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + + pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance() + pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance() + key = pool_1._get_config_key(_config(clickzetta_module)) + + assert pool_1 is pool_2 + assert "username:instance:service:workspace:cluster:dify" in key + + +def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + connection = MagicMock() + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr( + clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection]) + ) + pool._configure_connection = MagicMock() + + created = pool._create_connection(config) + + assert created is connection + assert clickzetta_module.clickzetta.connect.call_count == 2 + pool._configure_connection.assert_called_once_with(connection) + + +def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + pool._create_connection(config) + + +def test_connection_pool_configure_and_validate_connection(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection = MagicMock() + connection.cursor.return_value = cursor + + pool._configure_connection(connection) + assert cursor.execute.call_count >= 2 + assert pool._is_connection_valid(connection) is True + + bad_connection = MagicMock() + bad_connection.cursor.side_effect = RuntimeError("bad connection") + assert pool._is_connection_valid(bad_connection) is False + monkeypatch.undo() + + +def test_connection_pool_configure_connection_swallows_errors(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + connection = MagicMock() + connection.cursor.side_effect = RuntimeError("cannot configure") + + pool._configure_connection(connection) + monkeypatch.undo() + + +def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + key = pool._get_config_key(config) + + created_connection = MagicMock() + pool._create_connection = MagicMock(return_value=created_connection) + first = pool.get_connection(config) + assert first is created_connection + + reusable_connection = MagicMock() + pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())] + pool._is_connection_valid = MagicMock(return_value=True) + reused = pool.get_connection(config) + assert reused is reusable_connection + + expired_connection = MagicMock() + pool._pools[key] = [(expired_connection, 0.0)] + pool._is_connection_valid = MagicMock(return_value=False) + monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0)) + pool.get_connection(config) + expired_connection.close.assert_called_once() + + random_connection = MagicMock() + pool._is_connection_valid = MagicMock(return_value=True) + pool.return_connection(config, random_connection) + assert len(pool._pools[key]) == 1 + + pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)] + pool._connection_timeout = 10 + pool._cleanup_expired_connections() + assert len(pool._pools[key]) == 1 + + unknown_pool = MagicMock() + pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool) + unknown_pool.close.assert_called_once() + + pool.shutdown() + assert pool._shutdown is True + + +def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True)) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + self.started = False + + def start(self): + self.started = True + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + + assert pool._cleanup_thread.started is True + pool._cleanup_expired_connections.assert_called_once() + + +def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch): + pool = MagicMock() + pool.get_connection.return_value = "conn" + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool)) + monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock()) + + vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module)) + assert vector._table_name == "my_collection" + + assert vector._get_connection() == "conn" + vector._return_connection("conn") + pool.return_connection.assert_called_with(vector._config, "conn") + + with vector.get_connection_context() as conn: + assert conn == "conn" + assert pool.return_connection.call_count >= 2 + + assert vector.get_type() == "clickzetta" + assert vector._ensure_connection() == "conn" + + +def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch): + class _Thread: + def __init__(self, target, daemon): + self.target = target + self.daemon = daemon + self.started = 0 + + def start(self): + self.started += 1 + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_thread = None + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._init_write_queue() + clickzetta_module.ClickzettaVector._init_write_queue() + assert clickzetta_module.ClickzettaVector._write_thread.started == 1 + + result_queue_ok = queue.Queue() + result_queue_fail = queue.Queue() + clickzetta_module.ClickzettaVector._write_queue = queue.Queue() + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok)) + clickzetta_module.ClickzettaVector._write_queue.put( + (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail) + ) + clickzetta_module.ClickzettaVector._write_queue.put(None) + clickzetta_module.ClickzettaVector._write_worker() + + assert result_queue_ok.get() == (True, 2) + failed = result_queue_fail.get() + assert failed[0] is False + assert isinstance(failed[1], RuntimeError) + + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + clickzetta_module.ClickzettaVector._write_queue = None + with pytest.raises(RuntimeError, match="Write queue not initialized"): + vector._execute_write(lambda: None) + + class _ImmediateSuccessQueue: + def put(self, task): + func, args, kwargs, result_q = task + result_q.put((True, func(*args, **kwargs))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue() + assert vector._execute_write(lambda x: x * 2, 3) == 6 + + class _ImmediateFailQueue: + def put(self, task): + _, _, _, result_q = task + result_q.put((False, ValueError("write failed"))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue() + with pytest.raises(ValueError, match="write failed"): + vector._execute_write(lambda: None) + + +def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_exists(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_exists + assert vector._table_exists() is True + + vector._execute_write = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="content", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._execute_write.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_table_and_indexes_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._create_vector_index = MagicMock() + vector._create_inverted_index = MagicMock() + + vector._table_exists = MagicMock(return_value=True) + vector._create_table_and_indexes([[0.1, 0.2]]) + vector._create_vector_index.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._create_table_and_indexes([[0.1, 0.2, 0.3]]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_called_once() + + vector._config.enable_inverted_index = False + vector._create_vector_index.reset_mock() + vector._create_inverted_index.reset_mock() + vector._create_table_and_indexes([]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_not_called() + + +def test_create_vector_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show index failed"), None] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("already exists")] + cursor.fetchall.return_value = [] + vector._create_vector_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("unexpected")] + cursor.fetchall.return_value = [] + with pytest.raises(RuntimeError, match="unexpected"): + vector._create_vector_index(cursor) + + +def test_create_inverted_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show failed"), None] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [ + None, + RuntimeError("already has index"), + None, + ] + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("other create failure")] + cursor.fetchall.return_value = [] + vector._create_inverted_index(cursor) + + +def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.batch_size = 2 + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id)) + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2"}), + Document(page_content="doc-3", metadata={"doc_id": "id-3"}), + ] + vectors = [[0.1], [0.2], [0.3]] + + vector.add_texts([], []) + vector._execute_write.assert_not_called() + + added_ids = vector.add_texts(docs, vectors) + assert added_ids == ["id-1", "id-2", "id-3"] + assert vector._execute_write.call_count == 2 + assert vector._execute_write.call_args_list[0].args == ( + vector._insert_batch, + docs[:2], + vectors[:2], + ["id-1", "id-2"], + 0, + 2, + 2, + ) + assert vector._execute_write.call_args_list[1].args == ( + vector._insert_batch, + docs[2:], + vectors[2:], + ["id-3"], + 2, + 2, + 2, + ) + + vector._insert_batch([], [], [], 0, 2, 1) + vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1) + + bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}}) + good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [bad_doc, good_doc], + [[0.1, 0.2], [0.3, 0.4]], + ["id-bad", "id-good"], + 0, + 2, + 1, + ) + + @contextmanager + def _ctx_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.executemany.side_effect = RuntimeError("insert failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_error + with pytest.raises(RuntimeError, match="insert failed"): + vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1) + + monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id") + vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector) + assert vector._safe_doc_id("") == "generated-id" + assert vector._safe_doc_id("!!!") == "generated-id" + + +def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._table_exists = MagicMock(return_value=True) + + vector.delete_by_ids(["id-1", "id-2"]) + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl + + vector._execute_write.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl + + vector._safe_doc_id = MagicMock(side_effect=lambda x: x) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._delete_by_ids_impl(["id-1", "id-2"]) + vector._delete_by_metadata_field_impl("document_id", "doc-1") + + +def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.vector_distance_function = "cosine_distance" + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + cosine_docs = vector.search_by_vector( + [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"} + ) + assert cosine_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._config.vector_distance_function = "l2_distance" + l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2) + + +def test_search_by_full_text_success_and_fallback(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx_success(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'), + ("seg-2", "content-2", "invalid-json"), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_success + docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1}) + assert len(docs) == 2 + assert docs[0].metadata["score"] == 1.0 + assert docs[1].metadata["doc_id"] == "seg-2" + + @contextmanager + def _ctx_failure(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("full text failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_failure + vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})]) + fallback_docs = vector.search_by_full_text("query", top_k=1) + assert fallback_docs == vector._search_by_like.return_value + + +def test_search_by_like_missing_table_and_delete_table(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=False) + assert vector._search_by_like("query", top_k=1) == [] + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector.delete() + + +def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._pools = {} + pool._pool_locks = {} + pool._max_pool_size = 1 + pool._connection_timeout = 10 + pool._lock = clickzetta_module.threading.Lock() + pool._shutdown = False + + config = _config(clickzetta_module) + key = pool._get_config_key(config) + pool._pools[key] = [(MagicMock(), 1.0)] + pool._pool_locks[key] = clickzetta_module.threading.Lock() + pool._is_connection_valid = MagicMock(return_value=False) + + conn = MagicMock() + pool.return_connection(config, conn) + conn.close.assert_called_once() + + pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)] + pool._cleanup_expired_connections() + pool.shutdown() + assert pool._shutdown is True + + +def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + + def _cleanup_then_fail(): + pool._shutdown = True + raise RuntimeError("cleanup failed") + + pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail) + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + + def start(self): + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + pool._cleanup_expired_connections.assert_called_once() + + +def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1") + assert parsed_non_dict["doc_id"] == "row-1" + assert parsed_non_dict["document_id"] == "row-1" + + parsed_none = vector._parse_metadata(None, "row-2") + assert parsed_none["doc_id"] == "row-2" + assert parsed_none["document_id"] == "row-2" + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_worker() + + class _BadQueue: + def get(self, timeout): + clickzetta_module.ClickzettaVector._shutdown = True + raise RuntimeError("queue failed") + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = _BadQueue() + clickzetta_module.ClickzettaVector._write_worker() + + +def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + cursor.execute.side_effect = [ + None, + RuntimeError("already has index with the same type cannot create inverted index"), + None, + ] + + vector._create_inverted_index(cursor) + + vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value)) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor_obj = MagicMock() + cursor_obj.__enter__.return_value = cursor_obj + cursor_obj.__exit__.return_value = None + connection.cursor.return_value = cursor_obj + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [SimpleNamespace(page_content="content", metadata="not-a-dict")], + [[0.1, 0.2]], + ["doc-1"], + 0, + 1, + 1, + ) + + +def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.enable_inverted_index = True + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_full_text("query") == [] + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", "[1,2,3]"), + ("seg-2", "content-2", None), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector.search_by_full_text("query") + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[1].metadata["doc_id"] == "seg-2" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py new file mode 100644 index 0000000000..9fea187615 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py @@ -0,0 +1,364 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_couchbase_modules(): + couchbase = types.ModuleType("couchbase") + couchbase_auth = types.ModuleType("couchbase.auth") + couchbase_cluster = types.ModuleType("couchbase.cluster") + couchbase_management = types.ModuleType("couchbase.management") + couchbase_management_search = types.ModuleType("couchbase.management.search") + couchbase_options = types.ModuleType("couchbase.options") + couchbase_vector = types.ModuleType("couchbase.vector_search") + couchbase_search = types.ModuleType("couchbase.search") + + class PasswordAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + + class ClusterOptions: + def __init__(self, auth): + self.auth = auth + + class SearchOptions: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorQuery: + def __init__(self, field, vector, top_k): + self.field = field + self.vector = vector + self.top_k = top_k + + class VectorSearch: + @staticmethod + def from_vector_query(vector_query): + return {"vector_query": vector_query} + + class QueryStringQuery: + def __init__(self, query): + self.query = query + + class SearchRequest: + @staticmethod + def create(payload): + return {"payload": payload} + + class SearchIndex: + def __init__(self, name, params, source_name): + self.name = name + self.params = params + self.source_name = source_name + + class _QueryResult: + def __init__(self, rows=None): + self._rows = rows or [] + + def execute(self): + return self + + def __iter__(self): + return iter(self._rows) + + class _SearchIter: + def __init__(self, rows=None): + self._rows = rows or [] + + def rows(self): + return self._rows + + class _Collection: + def __init__(self): + self.upsert = MagicMock(return_value=True) + + class _SearchIndexManager: + def __init__(self): + self.upsert_index = MagicMock() + + class _Scope: + def __init__(self): + self._collection = _Collection() + self._search_index_manager = _SearchIndexManager() + self.search = MagicMock(return_value=_SearchIter()) + + def collection(self, _name): + return self._collection + + def search_indexes(self): + return self._search_index_manager + + class _CollectionManager: + def __init__(self): + self.create_collection = MagicMock() + self.drop_collection = MagicMock() + self.get_all_scopes = MagicMock(return_value=[]) + + class _Bucket: + def __init__(self): + self._scope = _Scope() + self._collections = _CollectionManager() + + def scope(self, _scope_name): + return self._scope + + def collections(self): + return self._collections + + class Cluster: + def __init__(self, connection_string, options): + self.connection_string = connection_string + self.options = options + self._bucket = _Bucket() + self.wait_until_ready = MagicMock() + self.query = MagicMock(return_value=_QueryResult()) + + def bucket(self, _name): + return self._bucket + + couchbase_auth.PasswordAuthenticator = PasswordAuthenticator + couchbase_cluster.Cluster = Cluster + couchbase_management_search.SearchIndex = SearchIndex + couchbase_options.ClusterOptions = ClusterOptions + couchbase_options.SearchOptions = SearchOptions + couchbase_vector.VectorQuery = VectorQuery + couchbase_vector.VectorSearch = VectorSearch + couchbase_search.QueryStringQuery = QueryStringQuery + couchbase_search.SearchRequest = SearchRequest + + couchbase.search = couchbase_search + couchbase.management = couchbase_management + + return { + "couchbase": couchbase, + "couchbase.auth": couchbase_auth, + "couchbase.cluster": couchbase_cluster, + "couchbase.management": couchbase_management, + "couchbase.management.search": couchbase_management_search, + "couchbase.options": couchbase_options, + "couchbase.vector_search": couchbase_vector, + "couchbase.search": couchbase_search, + } + + +@pytest.fixture +def couchbase_module(monkeypatch): + for name, module in _build_fake_couchbase_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.couchbase.couchbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.CouchbaseConfig( + connection_string="couchbase://localhost", + user="user", + password="pass", + bucket_name="bucket", + scope_name="scope", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("connection_string", "", "CONNECTION_STRING is required"), + ("user", "", "COUCHBASE_USER is required"), + ("password", "", "COUCHBASE_PASSWORD is required"), + ("bucket_name", "", "COUCHBASE_PASSWORD is required"), + ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"), + ], +) +def test_couchbase_config_validation(couchbase_module, field, value, message): + values = _config(couchbase_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + couchbase_module.CouchbaseConfig.model_validate(values) + + +def test_init_sets_cluster_handles(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + assert vector._bucket_name == "bucket" + assert vector._scope_name == "scope" + vector._cluster.wait_until_ready.assert_called_once() + + +def test_create_and_create_collection_branches(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector) + vector._collection_name = "collection_1" + vector._client_config = _config(couchbase_module) + vector._scope_name = "scope" + vector._bucket_name = "bucket" + vector._bucket = MagicMock() + vector._scope = MagicMock() + vector._collection_exists = MagicMock(return_value=False) + vector.add_texts = MagicMock() + + monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c") + vector._create_collection = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock()) + + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(vector_length=2, uuid="uuid-1") + vector._bucket.collections().create_collection.assert_not_called() + + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._collection_exists = MagicMock(return_value=True) + vector._create_collection(vector_length=2, uuid="uuid-2") + vector._bucket.collections().create_collection.assert_not_called() + + vector._collection_exists = MagicMock(return_value=False) + vector._create_collection(vector_length=3, uuid="uuid-3") + + vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1") + vector._scope.search_indexes().upsert_index.assert_called_once() + search_index = vector._scope.search_indexes().upsert_index.call_args.args[0] + assert search_index.name == "collection_1_search" + assert ( + search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"] + == 3 + ) + couchbase_module.redis_client.set.assert_called_once() + + +def test_collection_exists_get_type_and_add_texts(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is True + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is False + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert vector._scope.collection("collection_1").upsert.call_count == 2 + assert vector.get_type() == couchbase_module.VectorType.COUCHBASE + + +def test_query_delete_helpers(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}])) + assert vector.text_exists("id-1") is True + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([])) + assert vector.text_exists("id-2") is False + + query_result = MagicMock() + query_result.execute.return_value = None + vector._cluster.query.return_value = query_result + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_document_id("id-1") + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._cluster.query.call_count >= 3 + + vector._cluster.query.side_effect = RuntimeError("delete failed") + vector.delete_by_ids(["id-3"]) + + +def test_search_methods_and_format_metadata(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9) + row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["document_id"] == "d-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector._scope.search.side_effect = RuntimeError("search error") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_vector([0.1], top_k=1) + + vector._scope.search.side_effect = None + row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3]) + docs = vector.search_by_full_text("hello", top_k=1) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == "x" + + vector._scope.search.side_effect = RuntimeError("full text failed") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_full_text("hello", top_k=1) + + assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2} + + +def test_delete_collection_and_factory(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + scopes = [ + SimpleNamespace(collections=[SimpleNamespace(name="other")]), + SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]), + ] + vector._bucket.collections().get_all_scopes.return_value = scopes + + vector.delete() + vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1") + + factory = couchbase_module.CouchbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + couchbase_module, + "current_app", + SimpleNamespace( + config={ + "COUCHBASE_CONNECTION_STRING": "couchbase://localhost", + "COUCHBASE_USER": "user", + "COUCHBASE_PASSWORD": "pass", + "COUCHBASE_BUCKET_NAME": "bucket", + "COUCHBASE_SCOPE_NAME": "scope", + } + ), + ) + + with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py new file mode 100644 index 0000000000..edd29a4649 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py @@ -0,0 +1,121 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0"}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_ja_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + + importlib.reload(base_module) + return importlib.reload(ja_module) + + +def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_not_called() + + +def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2, 0.3]], [{}]) + + vector._client.indices.create.assert_called_once() + kwargs = vector._client.indices.create.call_args.kwargs + assert kwargs["index"] == "test" + assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3 + elasticsearch_ja_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_ja_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_called_once() + + +def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch): + factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + elasticsearch_ja_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_HOST": "localhost", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + + with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py new file mode 100644 index 0000000000..9ecf0caa24 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py @@ -0,0 +1,405 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}}) + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), + delete=MagicMock(), + exists=MagicMock(return_value=False), + create=MagicMock(), + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + + return importlib.reload(module) + + +def _regular_config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "username": "elastic", + "password": "secret", + "verify_certs": False, + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +def _cloud_config(module, **overrides): + values = { + "use_cloud": True, + "cloud_url": "https://cloud.example:9243", + "api_key": "api-key", + "verify_certs": True, + "ca_certs": "/tmp/ca.pem", + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("values", "message"), + [ + ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"), + ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"), + ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"), + ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"), + ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"), + ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"), + ], +) +def test_elasticsearch_config_validation(elasticsearch_module, values, message): + with pytest.raises(ValidationError, match=message): + elasticsearch_module.ElasticSearchConfig.model_validate(values) + + +def test_init_client_cloud_configuration(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + result = vector._init_client(_cloud_config(elasticsearch_module)) + + assert result is client + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://cloud.example:9243"] + assert kwargs["api_key"] == "api-key" + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + +def test_init_client_regular_https_and_http_fallback(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client( + _regular_config( + elasticsearch_module, + host="https://es.example", + port=9443, + verify_certs=True, + ca_certs="/tmp/ca.pem", + ) + ) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://es.example:9443"] + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200)) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["http://es.internal:9200"] + assert "verify_certs" not in kwargs + + +def test_init_client_connection_failures(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + + client = MagicMock() + client.ping.return_value = False + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client): + with pytest.raises(ConnectionError, match="Failed to connect"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object( + elasticsearch_module, + "Elasticsearch", + side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"), + ): + with pytest.raises(ConnectionError, match="Vector database connection error"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")): + with pytest.raises(ConnectionError, match="initialization failed"): + vector._init_client(_regular_config(elasticsearch_module)) + + +def test_init_get_version_and_check_version(elasticsearch_module): + with ( + patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client, + patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version, + patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version, + ): + vector = elasticsearch_module.ElasticSearchVector( + "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"] + ) + + init_client.assert_called_once() + get_version.assert_called_once() + check_version.assert_called_once() + assert vector._attributes == ["doc_id"] + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._client = MagicMock() + vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}} + assert vector._get_version() == "8.13.2" + + vector._version = "7.17.0" + with pytest.raises(ValueError, match="greater than 8.0.0"): + vector._check_version() + + vector._version = "8.0.0" + vector._check_version() + + +def test_crud_methods_and_get_type(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock()) + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection_1") + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._client.delete.call_count == 2 + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "d1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "d2") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1") + assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH + + +def test_search_by_vector_and_full_text(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.8, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-a", + elasticsearch_module.Field.VECTOR: [0.1], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-b", + elasticsearch_module.Field.VECTOR: [0.2], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + knn = vector._client.search.call_args.kwargs["knn"] + assert knn["k"] == 2 + assert knn["num_candidates"] == 3 + assert "filter" in knn + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "text-hit", + elasticsearch_module.Field.VECTOR: [0.3], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + query = vector._client.search.call_args.kwargs["query"] + assert "bool" in query + + +def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + mappings = vector._client.indices.create.call_args.kwargs["mappings"] + assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2 + elasticsearch_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + elasticsearch_module.redis_client.set.assert_called_once() + + +def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch): + factory = elasticsearch_module.ElasticSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": False, + "ELASTICSEARCH_HOST": "es-host", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + "ELASTICSEARCH_VERIFY_CERTS": False, + } + ), + ) + + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + assert result_1 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION" + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic", + "ELASTICSEARCH_API_KEY": "api-key", + "ELASTICSEARCH_VERIFY_CERTS": True, + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + assert result_2 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is True + assert cfg.cloud_url == "https://cloud.elastic" + assert dataset_without_index.index_struct is not None + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": None, + "ELASTICSEARCH_HOST": "fallback-host", + "ELASTICSEARCH_PORT": 9201, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert cfg.host == "fallback-host" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py new file mode 100644 index 0000000000..5d9e744ded --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py @@ -0,0 +1,371 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_hologres_modules(): + holo_module = types.ModuleType("holo_search_sdk") + holo_types_module = types.ModuleType("holo_search_sdk.types") + + holo_types_module.BaseQuantizationType = str + holo_types_module.DistanceType = str + holo_types_module.TokenizerType = str + + def _connect(**kwargs): + client = MagicMock() + client.kwargs = kwargs + client.connect = MagicMock() + client.check_table_exist = MagicMock(return_value=False) + client.open_table = MagicMock(return_value=MagicMock()) + client.execute = MagicMock(return_value=[]) + client.drop_table = MagicMock() + return client + + holo_module.connect = MagicMock(side_effect=_connect) + + return { + "holo_search_sdk": holo_module, + "holo_search_sdk.types": holo_types_module, + } + + +@pytest.fixture +def hologres_module(monkeypatch): + for name, module in _build_fake_hologres_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.hologres.hologres_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.HologresVectorConfig( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema_name="public", + tokenizer="jieba", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config HOLOGRES_HOST is required"), + ("database", "", "config HOLOGRES_DATABASE is required"), + ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"), + ], +) +def test_hologres_config_validation(hologres_module, field, value, message): + values = _valid_config(hologres_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + hologres_module.HologresVectorConfig.model_validate(values) + + +def test_init_client_and_get_type(hologres_module): + vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module)) + + hologres_module.holo.connect.assert_called_once_with( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema="public", + ) + vector._client.connect.assert_called_once() + assert vector.table_name == "embedding_collection_one" + assert vector.get_type() == hologres_module.VectorType.HOLOGRES + + +def test_create_delegates_collection_creation_and_upsert(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result is None + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_returns_empty_for_empty_documents(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + assert vector.add_texts([], []) == [] + vector._client.open_table.assert_not_called() + + +def test_add_texts_batches_and_serializes_metadata(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + table = vector._client.open_table.return_value + documents = [ + Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"}) + for i in range(100) + ] + documents.append(SimpleNamespace(page_content="doc-100", metadata=None)) + embeddings = [[float(i)] for i in range(len(documents))] + + ids = vector.add_texts(documents, embeddings) + + assert ids[:2] == ["id-0", "id-1"] + assert ids[-1] == "" + assert len(ids) == 101 + assert vector._client.open_table.call_count == 2 + assert table.upsert_multi.call_count == 2 + first_call = table.upsert_multi.call_args_list[0].kwargs + second_call = table.upsert_multi.call_args_list[1].kwargs + assert first_call["index_column"] == "id" + assert first_call["column_names"] == ["id", "text", "meta", "embedding"] + assert first_call["update_columns"] == ["text", "meta", "embedding"] + assert len(first_call["values"]) == 100 + assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"} + assert second_call["values"][0][0] == "" + assert second_call["values"][0][2] == "{}" + + +def test_text_exists_handles_missing_and_present_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + vector._client.execute.return_value = [(1,)] + + assert vector.text_exists("seg-1") is False + assert vector.text_exists("seg-1") is True + vector._client.execute.assert_called_once() + + +def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +def test_delete_by_ids_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + vector.delete_by_ids([]) + vector._client.check_table_exist.assert_not_called() + + vector._client.check_table_exist.return_value = False + vector.delete_by_ids(["id-1"]) + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_ids(["id-1", "id-2"]) + vector._client.execute.assert_called_once() + + +def test_delete_by_metadata_field_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_called_once() + + +def test_search_by_vector_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_vector([0.1, 0.2]) == [] + + +def test_search_by_vector_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + query = MagicMock() + table.search_vector.return_value = query + query.select.return_value = query + query.limit.return_value = query + query.where.return_value = query + query.fetchall.return_value = [ + (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'), + (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["doc-1"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + table.search_vector.assert_called_once() + query.where.assert_called_once() + + +def test_search_by_full_text_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_full_text("query") == [] + + +def test_search_by_full_text_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + search_query = MagicMock() + table.search_text.return_value = search_query + search_query.limit.return_value = search_query + search_query.where.return_value = search_query + search_query.fetchall.return_value = [ + ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95), + ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"]) + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.95) + assert docs[1].metadata["score"] == pytest.approx(0.7) + table.search_text.assert_called_once() + search_query.where.assert_called_once() + + +def test_delete_handles_existing_and_missing_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + + vector.delete() + vector._client.drop_table.assert_not_called() + + vector.delete() + vector._client.drop_table.assert_called_once_with(vector.table_name) + + +def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection(3) + + vector._client.check_table_exist.assert_not_called() + hologres_module.redis_client.set.assert_not_called() + + +def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, False, True] + table = vector._client.open_table.return_value + + vector._create_collection(3) + + vector._client.execute.assert_called_once() + table.set_vector_index.assert_called_once_with( + column="embedding", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + use_reorder=True, + ) + table.create_text_index.assert_called_once_with( + index_name="ft_idx_collection_one", + column="text", + tokenizer="jieba", + ) + hologres_module.redis_client.set.assert_called_once() + + +def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False] + [False] * 15 + + with pytest.raises(RuntimeError, match="was not ready after 30s"): + vector._create_collection(3) + + hologres_module.redis_client.set.assert_not_called() + + +def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch): + factory = hologres_module.HologresVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400) + + with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection" + generated_config = vector_cls.call_args_list[1].kwargs["config"] + assert generated_config.host == "127.0.0.1" + assert generated_config.database == "dify" + assert generated_config.access_key_id == "ak" + assert json.loads(dataset_without_index.index_struct) == { + "type": hologres_module.VectorType.HOLOGRES, + "vector_store": {"class_prefix": "generated_collection"}, + } diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py new file mode 100644 index 0000000000..9d23dfcf63 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py @@ -0,0 +1,243 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def huawei_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass") + + +def test_create_ssl_context(huawei_module): + ctx = huawei_module.create_ssl_context() + assert ctx.check_hostname is False + assert ctx.verify_mode == huawei_module.ssl.CERT_NONE + + +def test_huawei_config_validation_and_params(huawei_module): + with pytest.raises(ValidationError, match="HOSTS is required"): + huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""}) + + config = _config(huawei_module) + params = config.to_elasticsearch_params() + assert params["hosts"] == ["http://localhost:9200"] + assert params["basic_auth"] == ("user", "pass") + + config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None) + params = config.to_elasticsearch_params() + assert "basic_auth" not in params + + +def test_init_get_type_and_add_texts(huawei_module): + vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module)) + + assert vector._collection_name == "collection" + assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_crud_methods(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once_with(index="collection", id="id-1") + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection") + + +def test_search_by_vector_and_full_text(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + }, + { + "_score": 0.1, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-b", + huawei_module.Field.VECTOR: [0.2], + huawei_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + query_body = vector._client.search.call_args.kwargs["body"] + assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2 + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + huawei_module.Field.CONTENT_KEY: "text-hit", + huawei_module.Field.VECTOR: [0.3], + huawei_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + + +def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch): + class FakeDocument: + def __init__(self, page_content, vector, metadata): + self.page_content = page_content + self.vector = vector + self.metadata = None + + monkeypatch.setattr(huawei_module, "Document", FakeDocument) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + } + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5) + + assert docs == [] + + +def test_create_and_create_collection_paths(huawei_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock()) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + + kwargs = vector._client.indices.create.call_args.kwargs + mappings = kwargs["mappings"] + assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2 + assert kwargs["settings"] == {"index.vector": True} + huawei_module.redis_client.set.assert_called_once() + + +def test_huawei_factory_branches(huawei_module, monkeypatch): + factory = huawei_module.HuaweiCloudVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass") + + with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py new file mode 100644 index 0000000000..63338ca809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py @@ -0,0 +1,412 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_iris_module(): + iris = types.ModuleType("iris") + + def connect(**_kwargs): + conn = MagicMock() + conn.cursor.return_value = MagicMock() + return conn + + iris.connect = MagicMock(side_effect=connect) + return iris + + +@pytest.fixture +def iris_module(monkeypatch): + monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) + + import core.rag.datasource.vdb.iris.iris_vector as module + + reloaded = importlib.reload(module) + reloaded._pool_instance = None + return reloaded + + +def _config(module, **overrides): + values = { + "IRIS_HOST": "localhost", + "IRIS_SUPER_SERVER_PORT": 1972, + "IRIS_USER": "user", + "IRIS_PASSWORD": "pass", + "IRIS_DATABASE": "db", + "IRIS_SCHEMA": "schema", + "IRIS_CONNECTION_URL": "url", + "IRIS_MIN_CONNECTION": 1, + "IRIS_MAX_CONNECTION": 2, + "IRIS_TEXT_INDEX": True, + "IRIS_TEXT_INDEX_LANGUAGE": "en", + } + values.update(overrides) + return module.IrisVectorConfig.model_validate(values) + + +def test_get_iris_pool_singleton(iris_module): + iris_module._pool_instance = None + cfg = _config(iris_module) + + with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls: + pool_1 = iris_module.get_iris_pool(cfg) + pool_2 = iris_module.get_iris_pool(cfg) + + assert pool_1 == "pool" + assert pool_2 == "pool" + pool_cls.assert_called_once_with(cfg) + + +@pytest.fixture +def pool_with_min_max(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn: + pool = iris_module.IrisConnectionPool(cfg) + yield pool, create_conn + + +def test_pool_initialization_respects_min_max(pool_with_min_max): + pool, create_conn = pool_with_min_max + assert len(pool._pool) == 2 + assert create_conn.call_count == 2 + + +@pytest.fixture +def pool_for_get_connection(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_get_connection_returns_existing_and_increments(pool_for_get_connection): + pool = pool_for_get_connection + conn = MagicMock() + pool._pool = [conn] + pool._in_use = 0 + assert pool.get_connection() is conn + assert pool._in_use == 1 + + +def test_get_connection_creates_new_when_empty(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = 0 + pool._create_connection = MagicMock(return_value="new-conn") + assert pool.get_connection() == "new-conn" + + +def test_get_connection_raises_when_exhausted(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = pool._max_size + with pytest.raises(RuntimeError, match="exhausted"): + pool.get_connection() + + +@pytest.fixture +def pool_for_return_connection(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_return_connection_adds_healthy(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool.return_connection(conn) + assert pool._pool[-1] is conn + assert pool._in_use == 0 + + +def test_return_connection_replaces_bad(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + bad_conn = MagicMock() + bad_cursor = MagicMock() + bad_cursor.execute.side_effect = OSError("bad") + bad_conn.cursor.return_value = bad_cursor + replacement = MagicMock() + pool._create_connection = MagicMock(return_value=replacement) + pool.return_connection(bad_conn) + bad_conn.close.assert_called_once() + assert pool._pool[-1] is replacement + assert pool._in_use == 0 + + +def test_return_connection_ignores_none(pool_for_return_connection): + pool = pool_for_return_connection + before = len(pool._pool) + pool.return_connection(None) + assert len(pool._pool) == before + + +@pytest.fixture +def pool_for_schema_and_close(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool._pool = [conn] + return pool, conn, cursor + + +def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = {"cached_schema"} + pool.ensure_schema_exists("cached_schema") + cursor.execute.assert_not_called() + + +def test_ensure_schema_exists_creates_new(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (0,) + pool.ensure_schema_exists("new_schema") + assert "new_schema" in pool._schemas_initialized + assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list) + conn.commit.assert_called_once() + + +def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (1,) + pool.ensure_schema_exists("existing_schema") + conn.commit.assert_not_called() + + +def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.execute.side_effect = RuntimeError("schema failure") + with pytest.raises(RuntimeError, match="schema failure"): + pool.ensure_schema_exists("broken_schema") + conn.rollback.assert_called() + + +def test_close_all_closes_and_resets(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + conn_2 = MagicMock() + conn_2.close.side_effect = OSError("close fail") + pool._pool = [conn, conn_2] + pool._schemas_initialized = {"x"} + pool.close_all() + assert pool._pool == [] + assert pool._in_use == 0 + assert pool._schemas_initialized == set() + + +def test_iris_vector_init_get_cursor_and_create(iris_module): + pool = MagicMock() + pool.get_connection.return_value = MagicMock() + + with patch.object(iris_module, "get_iris_pool", return_value=pool): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + assert vector.table_name == "EMBEDDING_COLLECTION" + assert vector.schema == "schema" + assert vector.get_type() == iris_module.VectorType.IRIS + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + + with vector._get_cursor() as got_cursor: + assert got_cursor is cursor + conn.commit.assert_called_once() + vector.pool.return_connection.assert_called_with(conn) + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + with pytest.raises(RuntimeError, match="boom"): + with vector._get_cursor(): + raise RuntimeError("boom") + conn.rollback.assert_called_once() + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["id-1"]) + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"] + vector._create_collection.assert_called_once_with(2) + + +def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id") + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "generated-id"] + assert cursor.execute.call_count == 2 + + cursor.fetchone.return_value = (1,) + assert vector.text_exists("id-1") is True + cursor.fetchone.return_value = None + assert vector.text_exists("id-2") is False + + vector._get_cursor = MagicMock(side_effect=RuntimeError("db down")) + assert vector.text_exists("id-3") is False + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + before = cursor.execute.call_count + vector.delete_by_ids(["id-1", "id-2"]) + assert cursor.execute.call_count == before + 1 + + vector.delete_by_metadata_field("document_id", "doc-1") + assert "meta LIKE" in cursor.execute.call_args.args[0] + + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.9), + ("id-2", "text-2", '{"document_id":"d-2"}', 0.2), + ("id-x",), + ] + docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + +def test_iris_vector_full_text_search_paths(iris_module, monkeypatch): + cfg = _config(iris_module, IRIS_TEXT_INDEX=True) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", cfg) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + cursor.execute.side_effect = None + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.7), + ("id-2", "text-2", "{}", None), + ] + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.7) + assert docs[1].metadata["score"] == pytest.approx(0.0) + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("rank failed"), None] + cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)] + docs = vector.search_by_full_text("query", top_k=1) + assert len(docs) == 1 + assert cursor.execute.call_count == 2 + + cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector_like = iris_module.IrisVector("collection", cfg_like) + vector_like._get_cursor = _cursor_ctx + + fake_libs = types.ModuleType("libs") + fake_helper = types.ModuleType("libs.helper") + fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%") + monkeypatch.setitem(sys.modules, "libs", fake_libs) + monkeypatch.setitem(sys.modules, "libs.helper", fake_helper) + + cursor.reset_mock() + cursor.execute.side_effect = None + cursor.fetchall.return_value = [] + assert vector_like.search_by_full_text("100%", top_k=1) == [] + + +def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete() + assert "DROP TABLE" in cursor.execute.call_args.args[0] + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(iris_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_called_once() + + cursor.reset_mock() + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None)) + vector.pool.ensure_schema_exists = MagicMock() + vector._create_collection(3) + assert cursor.execute.call_count == 3 + iris_module.redis_client.set.assert_called_once() + + cursor.reset_mock() + vector.config.IRIS_TEXT_INDEX = False + vector._create_collection(3) + assert cursor.execute.call_count == 2 + + factory = iris_module.IrisVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972) + monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user") + monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass") + monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema") + monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url") + monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1) + monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en") + + with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py new file mode 100644 index 0000000000..34357d5907 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py @@ -0,0 +1,394 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearch_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = SimpleNamespace( + refresh=MagicMock(), + exists=MagicMock(return_value=False), + delete=MagicMock(), + create=MagicMock(), + ) + self.bulk = MagicMock(return_value={"errors": False, "items": []}) + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.delete_by_query = MagicMock() + self.get = MagicMock(return_value={"_id": "id"}) + self.exists = MagicMock(return_value=True) + + opensearch_helpers.BulkIndexError = BulkIndexError + opensearch_helpers.bulk = MagicMock() + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.helpers = opensearch_helpers + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearch_helpers, + } + + +@pytest.fixture +def lindorm_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.lindorm.lindorm_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.LindormVectorStoreConfig( + hosts="http://localhost:9200", + username="user", + password="pass", + using_ugc=False, + request_timeout=3.0, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("hosts", None, "config URL is required"), + ("username", None, "config USERNAME is required"), + ("password", None, "config PASSWORD is required"), + ], +) +def test_lindorm_config_validation(lindorm_module, field, value, message): + values = _config(lindorm_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + lindorm_module.LindormVectorStoreConfig.model_validate(values) + + +def test_to_opensearch_params_and_init(lindorm_module): + cfg = _config(lindorm_module) + params = cfg.to_opensearch_params() + + assert params["hosts"] == "http://localhost:9200" + assert params["http_auth"] == ("user", "pass") + + vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False) + assert vector._collection_name == "collection" + assert vector.get_type() == lindorm_module.VectorType.LINDORM + + with pytest.raises(ValueError, match="routing_value"): + lindorm_module.LindormVectorStore("c", cfg, using_ugc=True) + + vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE") + assert vector_ugc._routing == "route" + + +def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock()) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + + vector.add_texts(docs, embeddings, batch_size=2, timeout=9) + + assert vector._client.bulk.call_count == 2 + actions = vector._client.bulk.call_args_list[0].args[0] + assert actions[0]["index"]["routing"] == "route" + assert actions[1][lindorm_module.ROUTING_FIELD] == "route" + vector.refresh() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_add_texts_error_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]} + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + vector._client.bulk.side_effect = RuntimeError("bulk failed") + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + +def test_metadata_lookup_and_delete_by_metadata(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + query = vector._client.search.call_args.kwargs["body"] + must_conditions = query["query"]["bool"]["must"] + assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions) + + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"]) + + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-2") + vector.delete_by_ids.assert_not_called() + + +def test_delete_by_ids_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + vector.delete_by_ids([]) + vector._client.indices.exists.assert_not_called() + + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["id-1"]) + + vector._client.indices.exists.return_value = True + vector._client.exists.side_effect = [True, False] + lindorm_module.helpers.bulk.reset_mock() + vector.delete_by_ids(["id-1", "id-2"]) + lindorm_module.helpers.bulk.assert_called_once() + actions = lindorm_module.helpers.bulk.call_args.args[1] + assert len(actions) == 1 + assert actions[0]["routing"] == "route" + + lindorm_module.helpers.bulk.reset_mock() + lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError( + errors=[ + {"delete": {"status": 404, "_id": "id-404"}}, + {"delete": {"status": 500, "_id": "id-500"}}, + ] + ) + vector._client.exists.side_effect = [True] + vector.delete_by_ids(["id-1"]) + + +def test_delete_and_text_exists(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.delete() + vector._client.delete_by_query.assert_called_once() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.indices.exists.return_value = True + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60}) + + vector._client.indices.delete.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete() + vector._client.indices.delete.assert_not_called() + + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("missing") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validation_and_success(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + with pytest.raises(ValueError, match="should be a list"): + vector.search_by_vector("bad") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, "bad"]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-b", + lindorm_module.Field.VECTOR: [0.2], + lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + call_kwargs = vector._client.search.call_args.kwargs + query = call_kwargs["body"] + assert "ext" in query + assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"] + assert call_kwargs["params"]["routing"] == "route" + + vector._client.search.side_effect = RuntimeError("search failed") + with pytest.raises(RuntimeError, match="search failed"): + vector.search_by_vector([0.1]) + + +def test_search_by_full_text_success_and_error(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1"}, + } + } + ] + } + } + + docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] + + vector._client.search.side_effect = RuntimeError("full text failed") + with pytest.raises(RuntimeError, match="full text failed"): + vector.search_by_full_text("hello") + + +def test_create_collection_paths(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + + with pytest.raises(ValueError, match="cannot be empty"): + vector.create_collection([]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"}) + vector._client.indices.create.assert_called_once() + body = vector._client.indices.create.call_args.kwargs["body"] + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf" + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine" + + vector._client.indices.create.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + +def test_lindorm_factory_branches(lindorm_module, monkeypatch): + factory = lindorm_module.LindormVectorStoreFactory() + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2") + monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={}) + embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3]) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None) + with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"): + factory.init_vector(dataset, attributes=[], embeddings=embeddings) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + + dataset_existing_plain = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False}, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings) + assert result == "vector" + assert store_cls.call_args.args[0] == "existing" + + dataset_existing_ugc = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={ + "vector_store": {"class_prefix": "ROUTING"}, + "using_ugc": True, + "dimension": 1536, + "index_type": "hnsw", + "distance_type": "l2", + }, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "ROUTING" + + dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={}) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "auto_collection" + assert dataset_new.index_struct is not None + + dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={}) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "auto_collection" + assert store_cls.call_args.kwargs["routing_value"] is None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py new file mode 100644 index 0000000000..55e7b9112e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py @@ -0,0 +1,252 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_mo_vector_modules(): + mo_vector = types.ModuleType("mo_vector") + mo_vector.__path__ = [] + mo_vector_client = types.ModuleType("mo_vector.client") + + class MoVectorClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_full_text_index = MagicMock() + self.insert = MagicMock() + self.get = MagicMock(return_value=[]) + self.delete = MagicMock() + self.query_by_metadata = MagicMock(return_value=[]) + self.query = MagicMock(return_value=[]) + self.full_text_query = MagicMock(return_value=[]) + + mo_vector_client.MoVectorClient = MoVectorClient + mo_vector.client = mo_vector_client + return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client} + + +@pytest.fixture +def matrixone_module(monkeypatch): + for name, module in _build_fake_mo_vector_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.matrixone.matrixone_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.MatrixoneConfig( + host="localhost", + port=6001, + user="dump", + password="111", + database="dify", + metric="l2", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config host is required"), + ("port", 0, "config port is required"), + ("user", "", "config user is required"), + ("password", "", "config password is required"), + ("database", "", "config database is required"), + ], +) +def test_matrixone_config_validation(matrixone_module, field, value, message): + values = _valid_config(matrixone_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + matrixone_module.MatrixoneConfig.model_validate(values) + + +def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client.kwargs["table_name"] == "collection_1" + client.create_full_text_index.assert_called_once() + matrixone_module.redis_client.set.assert_called_once() + + +def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + client.create_full_text_index.assert_not_called() + matrixone_module.redis_client.set.assert_not_called() + + +def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = None + fake_client = MagicMock() + fake_client.get.return_value = [{"id": "seg-1"}] + vector._get_client = MagicMock(return_value=fake_client) + + exists = vector.text_exists("seg-1") + + assert exists is True + vector._get_client.assert_called_once_with(None, False) + + +def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.full_text_query.return_value = [ + SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}} + + +def test_get_type_and_create_delegate_to_add_texts(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + fake_client = MagicMock() + vector._get_client = MagicMock(return_value=fake_client) + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "matrixone" + assert result == ["seg-1"] + vector._get_client.assert_called_once_with(2, True) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + failing_client = MagicMock() + failing_client.create_full_text_index.side_effect = RuntimeError("boom") + monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client)) + + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client is failing_client + matrixone_module.redis_client.set.assert_not_called() + + +def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + # For current prod code, only docs with metadata get ids, so only two ids + assert ids == ["doc-a", "generated-uuid"] + vector.client.insert.assert_called_once() + insert_kwargs = vector.client.insert.call_args.kwargs + # All lists passed to insert should be the same length + texts = insert_kwargs["texts"] + embeddings = insert_kwargs["embeddings"] + metadatas = insert_kwargs["metadatas"] + ids_insert = insert_kwargs["ids"] + assert len(texts) == len(embeddings) == len(metadatas) == len(docs) + # ids may be shorter than docs for current prod code, but should match number of docs with metadata + assert ids_insert == ["doc-a", "generated-uuid"] + + +def test_delete_and_metadata_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")] + + vector.delete_by_ids([]) + vector.client.delete.assert_not_called() + + vector.delete_by_ids(["seg-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + vector.delete() + + assert ids == ["seg-1", "seg-2"] + assert vector.client.delete.call_count == 3 + + +def test_search_by_vector_builds_documents(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query.return_value = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}), + ] + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"]) + + assert len(docs) == 2 + assert docs[0].page_content == "doc-a" + assert docs[1].metadata["doc_id"] == "2" + assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}} + + +def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch): + factory = matrixone_module.MatrixoneVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001) + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2") + + with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index fb2ddfe162..2ac2c40d38 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -1,18 +1,414 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + import pytest from pydantic import ValidationError -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig +from core.rag.models.document import Document -def test_default_value(): +def _build_fake_pymilvus_modules(): + pymilvus = types.ModuleType("pymilvus") + pymilvus.__path__ = [] + pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client") + pymilvus_orm = types.ModuleType("pymilvus.orm") + pymilvus_orm.__path__ = [] + pymilvus_orm_types = types.ModuleType("pymilvus.orm.types") + + class MilvusError(Exception): + pass + + class MilvusClient: + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.has_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock( + return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]} + ) + self.get_server_version = MagicMock(return_value="2.5.0") + self.insert = MagicMock(return_value=[1]) + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.drop_collection = MagicMock() + self.search = MagicMock(return_value=[[]]) + self.create_collection = MagicMock() + + class IndexParams: + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + class DataType: + JSON = "JSON" + VARCHAR = "VARCHAR" + INT64 = "INT64" + SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR" + FLOAT_VECTOR = "FLOAT_VECTOR" + + class FieldSchema: + def __init__(self, name, dtype, **kwargs): + self.name = name + self.dtype = dtype + self.kwargs = kwargs + + class CollectionSchema: + def __init__(self, fields): + self.fields = fields + self.functions = [] + + def add_function(self, func): + self.functions.append(func) + + class FunctionType: + BM25 = "BM25" + + class Function: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def infer_dtype_bydata(_value): + return DataType.FLOAT_VECTOR + + pymilvus.MilvusException = MilvusError + pymilvus.MilvusClient = MilvusClient + pymilvus.IndexParams = IndexParams + pymilvus.CollectionSchema = CollectionSchema + pymilvus.DataType = DataType + pymilvus.FieldSchema = FieldSchema + pymilvus.Function = Function + pymilvus.FunctionType = FunctionType + pymilvus_milvus_client.IndexParams = IndexParams + pymilvus_orm.types = pymilvus_orm_types + pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata + + # Attach submodules for dotted imports + pymilvus.milvus_client = pymilvus_milvus_client + pymilvus.orm = pymilvus_orm + + return { + "pymilvus": pymilvus, + "pymilvus.milvus_client": pymilvus_milvus_client, + "pymilvus.orm": pymilvus_orm, + "pymilvus.orm.types": pymilvus_orm_types, + } + + +@pytest.fixture +def milvus_module(monkeypatch): + for name, module in _build_fake_pymilvus_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.milvus.milvus_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "uri": "http://localhost:19530", + "user": "root", + "password": "Milvus", + "database": "default", + "enable_hybrid_search": False, + "analyzer_params": None, + } + values.update(overrides) + return module.MilvusConfig.model_validate(values) + + +def test_config_validation_and_defaults(milvus_module): valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig.model_validate(config) + milvus_module.MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig.model_validate(valid_config) + config = milvus_module.MilvusConfig.model_validate(valid_config) assert config.database == "default" + + token_config = milvus_module.MilvusConfig.model_validate( + {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"} + ) + assert token_config.token == "token-value" + + +def test_config_to_milvus_params(milvus_module): + config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + + params = config.to_milvus_params() + + assert params["uri"] == "http://localhost:19530" + assert params["db_name"] == "default" + assert params["analyzer_params"] == '{"tokenizer":"standard"}' + + +def test_init_client_supports_token_and_user_password(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + token_client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"} + + user_client = vector._init_client(_config(milvus_module)) + assert user_client.init_kwargs["uri"] == "http://localhost:19530" + assert user_client.init_kwargs["user"] == "root" + assert user_client.init_kwargs["password"] == "Milvus" + + +def test_init_loads_fields_when_collection_exists(milvus_module): + client = milvus_module.MilvusClient(uri="http://localhost:19530") + client.has_collection.return_value = True + client.describe_collection.return_value = { + "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}] + } + + with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client): + with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False): + vector = milvus_module.MilvusVector("collection_1", _config(milvus_module)) + + assert "id" not in vector._fields + assert "content" in vector._fields + + +def test_load_collection_fields_from_argument_and_remote(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + vector._collection_name = "collection_1" + vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]} + + vector._load_collection_fields(["id", "metadata"]) + assert vector._fields == ["metadata"] + + vector._load_collection_fields() + assert vector._fields == ["content"] + + +def test_check_hybrid_search_support_branches(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + + vector._client_config = SimpleNamespace(enable_hybrid_search=False) + assert vector._check_hybrid_search_support() is False + + vector._client_config = SimpleNamespace(enable_hybrid_search=True) + vector._client.get_server_version.return_value = "Zilliz Cloud 2.4" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.5.1" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.4.9" + assert vector._check_hybrid_search_support() is False + + vector._client.get_server_version.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_get_type_and_create_delegate(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [SimpleNamespace(page_content="hello", metadata=None)] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "milvus" + vector.create_collection.assert_called_once() + create_args = vector.create_collection.call_args.args + assert create_args[0] == [[0.1, 0.2]] + assert create_args[1] == [{}] + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_and_raises_milvus_exception(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.insert.side_effect = [["id-1"], ["id-2"]] + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + ids = vector.add_texts(docs, embeddings) + assert ids == ["id-1", "id-2"] + assert vector._client.insert.call_count == 2 + + vector._client.insert.side_effect = milvus_module.MilvusException("insert failed") + with pytest.raises(milvus_module.MilvusException): + vector.add_texts([Document(page_content="x", metadata={})], [[0.1]]) + + +def test_get_ids_and_delete_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.query.return_value = [{"id": 1}, {"id": 2}] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2] + vector._client.query.return_value = [] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector._client.has_collection.return_value = True + vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102]) + + vector._client.delete.reset_mock() + vector._client.query.return_value = [{"id": 11}, {"id": 12}] + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12]) + + vector._client.has_collection.return_value = True + vector.delete() + vector._client.drop_collection.assert_called_once_with("collection_1", None) + + +def test_text_exists_and_field_exists(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._fields = ["content", "metadata"] + vector._client = MagicMock() + vector._client.has_collection.return_value = False + assert vector.text_exists("doc-1") is False + + vector._client.has_collection.return_value = True + vector._client.query.return_value = [{"id": 1}] + assert vector.text_exists("doc-1") is True + vector._client.query.return_value = [] + assert vector.text_exists("doc-1") is False + assert vector.field_exists("content") is True + assert vector.field_exists("unknown") is False + + +def test_process_search_results_and_search_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._fields = ["content", "metadata", "sparse_vector"] + + processed = vector._process_search_results( + [ + [ + {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9}, + {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2}, + ] + ], + [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY], + score_threshold=0.5, + ) + assert len(processed) == 1 + assert processed[0].metadata["score"] == 0.9 + + vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1) + assert docs == ["doc"] + assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]' + + vector._hybrid_search_enabled = False + assert vector.search_by_full_text("query") == [] + + vector._hybrid_search_enabled = True + vector._fields = [] + assert vector.search_by_full_text("query") == [] + + vector._fields = [milvus_module.Field.SPARSE_VECTOR] + vector._process_search_results = MagicMock(return_value=["full-text-doc"]) + full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2) + assert full_text_docs == ["full-text-doc"] + assert "document_id" in vector._client.search.call_args.kwargs["filter"] + + +def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client_config = _config(milvus_module) + vector._hybrid_search_enabled = False + vector._client = MagicMock() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.has_collection.return_value = True + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + milvus_module.redis_client.set.assert_called() + + +def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client = MagicMock() + vector._client.has_collection.return_value = False + vector._load_collection_fields = MagicMock() + + vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + vector._hybrid_search_enabled = True + vector.create_collection( + embeddings=[[0.1, 0.2]], + metadatas=[{"doc_id": "1"}], + index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}}, + ) + + call_kwargs = vector._client.create_collection.call_args.kwargs + schema = call_kwargs["schema"] + index_params_obj = call_kwargs["index_params"] + field_names = [f.name for f in schema.fields] + + assert milvus_module.Field.SPARSE_VECTOR in field_names + assert len(schema.functions) == 1 + assert len(index_params_obj.indexes) == 2 + assert call_kwargs["consistency_level"] == "Session" + + +def test_factory_initializes_milvus_vector(milvus_module, monkeypatch): + factory = milvus_module.MilvusVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}') + + with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py new file mode 100644 index 0000000000..a75ba82238 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py @@ -0,0 +1,230 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickhouse_connect_module(): + clickhouse_connect = types.ModuleType("clickhouse_connect") + + class QueryResult: + def __init__(self, rows=None, named_rows=None): + self.row_count = len(rows or []) + self.result_rows = rows or [] + self._named_rows = named_rows or [] + + def named_results(self): + return self._named_rows + + class Client: + def __init__(self): + self.command = MagicMock() + self.query = MagicMock(return_value=QueryResult()) + + client = Client() + + def get_client(**_kwargs): + return client + + clickhouse_connect.get_client = get_client + clickhouse_connect.QueryResult = QueryResult + clickhouse_connect._fake_client = client + return clickhouse_connect + + +@pytest.fixture +def myscale_module(monkeypatch): + fake_module = _build_fake_clickhouse_connect_module() + monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) + + import core.rag.datasource.vdb.myscale.myscale_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ) + + +def test_escape_str_replaces_backslash_and_quote(myscale_module): + escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special") + assert escaped == "text with special" + + +def test_search_raises_for_invalid_top_k(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0) + + +def test_search_builds_where_clause_for_cosine_threshold(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__( + named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}] + ) + + docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2) + + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "WHERE dist < 0.8" in sql + + +def test_delete_by_ids_short_circuits_on_empty_list(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete_by_ids([]) + vector._client.command.assert_not_called() + + +def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch): + factory = myscale_module.MyScaleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123) + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "") + + with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +def test_init_and_get_type_set_expected_defaults(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + + assert vector.get_type() == "myscale" + assert vector._vec_order == myscale_module.SortOrder.ASC + vector._client.command.assert_called_with("SET allow_experimental_object_type=1") + + +def test_create_calls_create_collection_and_add_texts(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once() + + +def test_create_collection_builds_expected_sql(myscale_module): + config = myscale_module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="tokenizer=unicode", + ) + vector = myscale_module.MyScaleVector("collection_1", config) + vector._client.command.reset_mock() + + vector._create_collection(3) + + assert vector._client.command.call_count == 2 + sql = vector._client.command.call_args_list[1].args[0] + assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql + assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql + assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql + + +def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="text-2", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="text-3", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + sql = vector._client.command.call_args.args[0] + assert "INSERT INTO dify.collection_1" in sql + assert "te xt 1" in sql + + +def test_text_exists_and_metadata_operations(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)]) + + assert vector.text_exists("id-1") is True + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._client.command.call_count >= 2 + + +def test_search_delegation_methods(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._search = MagicMock(return_value=["result"]) + + result_vector = vector.search_by_vector([0.1, 0.2], top_k=2) + result_text = vector.search_by_full_text("hello", top_k=2) + + assert result_vector == ["result"] + assert result_text == ["result"] + assert vector._search.call_count == 2 + + +def test_search_with_document_filter_and_exception(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace( + named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}] + ) + + docs = vector._search( + "distance(vector, [0.1])", + myscale_module.SortOrder.ASC, + top_k=2, + document_ids_filter=["doc-1", "doc-2"], + ) + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql + + vector._client.query.side_effect = RuntimeError("boom") + assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == [] + + +def test_delete_drops_table(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete() + + vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py new file mode 100644 index 0000000000..27d8198ec0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py @@ -0,0 +1,553 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.exc import SQLAlchemyError + +from core.rag.models.document import Document + + +def _build_fake_pyobvector_module(): + pyobvector = types.ModuleType("pyobvector") + + class VECTOR: + def __init__(self, dim): + self.dim = dim + + def l2_distance(*_args, **_kwargs): + return "l2" + + def cosine_distance(*_args, **_kwargs): + return "cosine" + + def inner_product(*_args, **_kwargs): + return "inner_product" + + class ObVecClient: + def __init__(self, **_kwargs): + self.metadata_obj = SimpleNamespace(tables={}) + self.engine = MagicMock() + self.check_table_exists = MagicMock(return_value=False) + self.perform_raw_text_sql = MagicMock() + self.prepare_index_params = MagicMock() + self.create_table_with_index_params = MagicMock() + self.refresh_metadata = MagicMock() + self.insert = MagicMock() + self.refresh_index = MagicMock() + self.get = MagicMock() + self.delete = MagicMock() + self.set_ob_hnsw_ef_search = MagicMock() + self.ann_search = MagicMock(return_value=[]) + self.drop_table_if_exist = MagicMock() + + pyobvector.VECTOR = VECTOR + pyobvector.ObVecClient = ObVecClient + pyobvector.l2_distance = l2_distance + pyobvector.cosine_distance = cosine_distance + pyobvector.inner_product = inner_product + return pyobvector + + +@pytest.fixture +def oceanbase_module(monkeypatch): + monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) + + import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.OceanBaseVectorConfig( + host="127.0.0.1", + port=2881, + user="root", + password="secret", + database="test", + enable_hybrid_search=True, + batch_size=10, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OCEANBASE_VECTOR_HOST is required"), + ("port", 0, "config OCEANBASE_VECTOR_PORT is required"), + ("user", "", "config OCEANBASE_VECTOR_USER is required"), + ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"), + ], +) +def test_oceanbase_config_validation(oceanbase_module, field, value, message): + values = _config(oceanbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oceanbase_module.OceanBaseVectorConfig.model_validate(values) + + +def test_init_rejects_invalid_collection_name(oceanbase_module): + with pytest.raises(ValueError, match="Invalid collection name"): + oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module)) + + +def test_distance_to_score_for_supported_metrics(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="l2") + assert vector._distance_to_score(3.0) == pytest.approx(0.25) + + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._distance_to_score(0.2) == pytest.approx(0.8) + + vector._config = SimpleNamespace(metric_type="inner_product") + assert vector._distance_to_score(-0.2) == pytest.approx(0.2) + + +def test_get_distance_func_raises_for_unknown_metric(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="manhattan") + + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._get_distance_func() + + +def test_process_search_results_handles_json_and_score_threshold(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + rows = [ + ("doc-1", '{"doc_id":"1"}', 0.9), + ("doc-2", "not-json", 0.8), + ("doc-3", {"doc_id": "3"}, 0.3), + ] + + docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank") + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["rank"] == 0.9 + assert docs[1].metadata["rank"] == 0.8 + + +def test_search_by_vector_validates_document_id_format(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + + with pytest.raises(ValueError, match="Invalid document ID format"): + vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"]) + + +def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._hybrid_search_enabled = False + vector._collection_name = "collection_1" + + assert vector.search_by_full_text("query") == [] + + +def test_check_hybrid_search_support_uses_version_comment(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client = MagicMock() + cursor = MagicMock() + cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",) + vector._client.perform_raw_text_sql.return_value = cursor + + assert vector._check_hybrid_search_support() is True + + cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",) + assert vector._check_hybrid_search_support() is False + + +def test_init_get_type_and_field_loading(oceanbase_module): + config = _config(oceanbase_module) + config.enable_hybrid_search = False + + table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")]) + fake_client = oceanbase_module.ObVecClient() + fake_client.check_table_exists.return_value = True + fake_client.metadata_obj.tables = {"collection_1": table} + + with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client): + vector = oceanbase_module.OceanBaseVector("collection_1", config) + + assert vector.get_type() == "oceanbase" + assert vector.field_exists("text") is True + + +def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._fields = [] + vector._client = MagicMock() + vector._client.metadata_obj.tables = {} + + vector._load_collection_fields() + assert vector._fields == [] + + vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))} + vector._load_collection_fields() + assert vector._fields == [] + + +def test_create_delegates_to_collection_and_insert(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector._vec_dim == 2 + vector._create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = False + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection() + vector._client.check_table_exists.assert_not_called() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.check_table_exists.return_value = True + vector._create_collection() + vector.delete.assert_not_called() + + +def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 3 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + None, + None, + ] + index_params = MagicMock() + vector._client.prepare_index_params.return_value = index_params + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._create_collection() + + vector.delete.assert_called_once() + vector._client.create_table_with_index_params.assert_called_once() + index_params.add_index.assert_called_once() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + oceanbase_module.redis_client.set.assert_called_once() + + +def test_create_collection_error_paths(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.return_value = [] + with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "0"]], + RuntimeError("no privilege"), + ] + with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]] + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid") + with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"): + vector._create_collection() + + +def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + RuntimeError("fulltext failed"), + ] + with pytest.raises(Exception, match="Failed to add fulltext index"): + vector._create_collection() + + vector._hybrid_search_enabled = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + SQLAlchemyError("metadata index failed"), + ] + vector._create_collection() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + + +def test_check_hybrid_search_support_false_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=False) + vector._client = MagicMock() + assert vector._check_hybrid_search_support() is False + + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_add_texts_batches_refresh_and_exceptions(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2) + vector._client = MagicMock() + vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + + vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert vector._client.insert.call_count == 2 + vector._client.refresh_index.assert_called_once() + + vector._client.insert.reset_mock() + vector._client.refresh_index.reset_mock() + vector._client.insert.side_effect = RuntimeError("insert failed") + with pytest.raises(Exception, match="Failed to insert batch"): + vector.add_texts([docs[0]], [[0.1]]) + + vector._client.insert.side_effect = None + vector._client.insert.return_value = None + vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed") + vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1) + vector._get_uuids.return_value = ["id-1"] + vector.add_texts([docs[0]], [[0.1]]) + + +def test_text_exists_and_delete_by_ids(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.get.return_value = SimpleNamespace(rowcount=1) + assert vector.text_exists("id-1") is True + + vector._client.get.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to check text existence"): + vector.text_exists("id-1") + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + + vector._client.delete.side_effect = None + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once() + + vector._client.delete.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete documents"): + vector.delete_by_ids(["id-1"]) + + +def test_get_ids_and_delete_by_metadata_field(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + execute_result = [("id-1",), ("id-2",)] + + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value = execute_result + vector._client.engine.connect.return_value = conn + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + + with pytest.raises(Exception, match="Failed to query documents by metadata field"): + vector.get_ids_by_metadata_field("bad key!", "doc-1") + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_not_called() + + +def test_search_by_full_text_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hybrid_search_enabled = True + vector.field_exists = MagicMock(return_value=False) + + assert vector.search_by_full_text("query") == [] + + vector.field_exists.return_value = True + vector._client = MagicMock() + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = tx + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)] + vector._client.engine.connect.return_value = conn + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + with pytest.raises(Exception, match="Full-text search failed"): + vector.search_by_full_text("query", top_k=0) + + +def test_search_by_vector_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector( + [0.1, 0.2], + ef_search=10, + top_k=3, + score_threshold=0.1, + document_ids_filter=["good_id"], + ) + assert docs == ["doc"] + vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10) + + with pytest.raises(ValueError, match="Invalid score_threshold parameter"): + vector.search_by_vector([0.1], score_threshold="x") + + vector._client.ann_search.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Vector search failed"): + vector.search_by_vector([0.1], score_threshold=0.1) + + +def test_get_distance_func_and_distance_to_score_errors(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._get_distance_func() is oceanbase_module.cosine_distance + + vector._config = SimpleNamespace(metric_type="unknown") + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._distance_to_score(0.1) + + +def test_delete_success_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector.delete() + vector._client.drop_table_if_exist.assert_called_once_with("collection_1") + + vector._client.drop_table_if_exist.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete collection"): + vector.delete() + + +def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch): + factory = oceanbase_module.OceanBaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000) + + with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].args[0] == "existing_collection" + assert vector_cls.call_args_list[1].args[0] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py new file mode 100644 index 0000000000..6641dbe4a0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py @@ -0,0 +1,400 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def opengauss_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opengauss.opengauss as module + + return importlib.reload(module) + + +def _config(module, *, enable_pq=False): + return module.OpenGaussConfig( + host="localhost", + port=6600, + user="postgres", + password="password", + database="dify", + min_connection=1, + max_connection=5, + enable_pq=enable_pq, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENGAUSS_HOST is required"), + ("port", 0, "config OPENGAUSS_PORT is required"), + ("user", "", "config OPENGAUSS_USER is required"), + ("password", "", "config OPENGAUSS_PASSWORD is required"), + ("database", "", "config OPENGAUSS_DATABASE is required"), + ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"), + ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"), + ], +) +def test_opengauss_config_validation(opengauss_module, field, value, message): + values = _config(opengauss_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module): + values = _config(opengauss_module).model_dump() + values["min_connection"] = 6 + values["max_connection"] = 5 + + with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + + assert vector.table_name == "embedding_collection_1" + assert vector.get_type() == "opengauss" + assert vector.pool is pool + + +def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("enable_pq=on" in sql for sql in executed_sql) + assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql) + opengauss_module.redis_client.set.assert_called_once() + + +def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(3072) + + cursor.execute.assert_not_called() + opengauss_module.redis_client.set.assert_called_once() + + +def test_search_by_vector_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + +def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector.delete_by_ids([]) + + vector._get_cursor.assert_not_called() + + +def test_get_cursor_closes_commits_and_returns_connection(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + pool = MagicMock() + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + vector.pool = pool + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_calls_collection_insert_and_index(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + vector._create_index = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + vector._create_index.assert_called_once_with(2) + + +def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector._create_index(1536) + + vector._get_cursor.assert_not_called() + opengauss_module.redis_client.set.assert_not_called() + + +def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql) + + +def test_add_texts_uses_execute_values(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + opengauss_module.psycopg2.extras.execute_values.reset_mock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + docs = [ + Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}), + SimpleNamespace(page_content="text-2", metadata=None), + ] + monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid") + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["seg-1"] + opengauss_module.psycopg2.extras.execute_values.assert_called_once() + + +def test_text_exists_and_get_by_ids(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("seg-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("seg-1") is True + docs = vector.get_by_ids(["seg-1", "seg-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + +def test_delete_and_metadata_field_queries(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + vector.delete_by_ids(["seg-1", "seg-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql) + assert any("meta->>%s = %s" in query for query in sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql) + + +def test_search_by_vector_and_full_text(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.6), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_search_by_full_text_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=0) + + +def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(1536) + cursor.execute.assert_not_called() + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(1536) + cursor.execute.assert_called_once() + opengauss_module.redis_client.set.assert_called_once() + + +def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch): + factory = opengauss_module.OpenGaussFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False) + + with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py new file mode 100644 index 0000000000..1030158dd1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py @@ -0,0 +1,360 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "basic", + "user": "admin", + "password": "secret", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENSEARCH_HOST is required"), + ("port", 0, "config OPENSEARCH_PORT is required"), + ], +) +def test_config_validation_required_fields(opensearch_module, field, value, message): + values = _config(opensearch_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_config_validation_for_aws_auth_and_https_fields(opensearch_module): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "aws_managed_iam", + "user": "admin", + "password": "secret", + } + with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"): + opensearch_module.OpenSearchConfig.model_validate(values) + + values = _config(opensearch_module).model_dump() + values["OPENSEARCH_SECURE"] = False + values["OPENSEARCH_VERIFY_CERTS"] = True + with pytest.raises(ValidationError, match="verify_certs=True requires secure"): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-east-1", + aws_service="es", + ) + auth = config.create_aws_managed_iam_auth() + + assert auth.credentials == "creds" + assert auth.region == "us-east-1" + assert auth.service == "es" + + +def test_to_opensearch_params_supports_basic_and_aws(opensearch_module): + basic_params = _config(opensearch_module).to_opensearch_params() + assert basic_params["http_auth"] == ("admin", "secret") + + aws_config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-west-2", + aws_service="es", + ) + with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"): + aws_params = aws_config.to_opensearch_params() + + assert aws_params["http_auth"] == "iam-auth" + + +def test_init_and_create_delegate_calls(opensearch_module): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "opensearch" + vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es")) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id")) + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.1], [0.2]]) + actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert len(actions) == 2 + assert all("_id" in action for action in actions) + + vector._client_config.aws_service = "aoss" + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.3], [0.4]]) + aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert all("_id" not in action for action in aoss_actions) + + +def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector._client.search.return_value = {"hits": {"hits": []}} + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + +def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + opensearch_module.helpers.bulk.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["doc-1"]) + opensearch_module.helpers.bulk.assert_not_called() + + vector._client.indices.exists.return_value = True + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None]) + vector.delete_by_ids(["doc-1", "doc-2"]) + opensearch_module.helpers.bulk.assert_called_once() + + opensearch_module.helpers.bulk.reset_mock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"]) + opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError( + [{"delete": {"status": 404, "_id": "es-404"}}] + ) + vector.delete_by_ids(["doc-404"]) + assert opensearch_module.helpers.bulk.call_count == 1 + + opensearch_module.helpers.bulk.side_effect = None + + +def test_delete_and_text_exists(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True) + + vector._client.get.return_value = {"_id": "id-1"} + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("not found") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validates_and_builds_documents(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + + with pytest.raises(ValueError, match="query_vector should be a list"): + vector.search_by_vector("not-a-list") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, 1]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-1", + opensearch_module.Field.METADATA_KEY: None, + }, + "_score": 0.9, + }, + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-2", + opensearch_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + "_score": 0.1, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"]) + query = vector._client.search.call_args.kwargs["body"] + assert "script_score" in query["query"] + + +def test_search_by_vector_reraises_client_error(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.search_by_vector([0.1, 0.2]) + + +def test_search_by_full_text_and_filters(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.METADATA_KEY: {"doc_id": "1"}, + opensearch_module.Field.VECTOR: [0.1], + opensearch_module.Field.CONTENT_KEY: "matched text", + } + }, + ] + } + } + + docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "matched text" + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}] + + +def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock()) + + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1)) + vector._client.indices.create.reset_mock() + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_called_once() + index_body = vector._client.indices.create.call_args.kwargs["body"] + assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2 + opensearch_module.redis_client.set.assert_called() + + +def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch): + factory = opensearch_module.OpenSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None) + + with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py new file mode 100644 index 0000000000..817a7d342b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py @@ -0,0 +1,375 @@ +import array +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_oracle_modules(): + jieba = types.ModuleType("jieba") + jieba_posseg = types.ModuleType("jieba.posseg") + jieba_posseg.cut = MagicMock(return_value=[]) + jieba.posseg = jieba_posseg + + oracledb = types.ModuleType("oracledb") + oracledb_connection = types.ModuleType("oracledb.connection") + + class Connection: + pass + + oracledb_connection.Connection = Connection + oracledb.defaults = SimpleNamespace(fetch_lobs=True) + oracledb.DB_TYPE_VECTOR = object() + oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock())) + oracledb.connect = MagicMock() + + return { + "jieba": jieba, + "jieba.posseg": jieba_posseg, + "oracledb": oracledb, + "oracledb.connection": oracledb_connection, + } + + +def _connection_with_cursor(cursor): + cursor_ctx = MagicMock() + cursor_ctx.__enter__.return_value = cursor + cursor_ctx.__exit__.return_value = None + + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = None + connection.cursor.return_value = cursor_ctx + return connection + + +@pytest.fixture +def oracle_module(monkeypatch): + for name, module in _build_fake_oracle_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.oracle.oraclevector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "user": "system", + "password": "oracle", + "dsn": "oracle:1521/freepdb1", + "is_autonomous": False, + } + values.update(overrides) + return module.OracleVectorConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("user", "", "config ORACLE_USER is required"), + ("password", "", "config ORACLE_PASSWORD is required"), + ("dsn", "", "config ORACLE_DSN is required"), + ], +) +def test_oracle_config_validation_required_fields(oracle_module, field, value, message): + values = _config(oracle_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oracle_module.OracleVectorConfig.model_validate(values) + + +def test_oracle_config_validation_autonomous_requirements(oracle_module): + with pytest.raises(ValidationError, match="config_dir is required"): + oracle_module.OracleVectorConfig.model_validate( + {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True} + ) + + +def test_init_and_get_type(oracle_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool)) + vector = oracle_module.OracleVector("collection_1", _config(oracle_module)) + + assert vector.get_type() == "oracle" + assert vector.table_name == "embedding_collection_1" + assert vector.pool is pool + + +def test_numpy_converters_and_type_handlers(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + + in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64)) + in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32)) + in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8)) + assert in_float64.typecode == "d" + assert in_float32.typecode == "f" + assert in_int8.typecode == "b" + + cursor = MagicMock() + vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2) + cursor.var.assert_called_with( + oracle_module.oracledb.DB_TYPE_VECTOR, + arraysize=2, + inconverter=vector.numpy_converter_in, + ) + + metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR) + cursor.arraysize = 3 + vector.output_type_handler(cursor, metadata) + cursor.var.assert_called_with( + metadata.type_code, + arraysize=3, + outconverter=vector.numpy_converter_out, + ) + + out_int8 = vector.numpy_converter_out(array.array("b", [1])) + assert out_int8.dtype == numpy.int8 + out_float32 = vector.numpy_converter_out(array.array("f", [1.0])) + assert out_float32.dtype == numpy.float32 + out_float64 = vector.numpy_converter_out(array.array("d", [1.0])) + assert out_float64.dtype == numpy.float64 + + +def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch): + connect = MagicMock(return_value="connection") + monkeypatch.setattr(oracle_module.oracledb, "connect", connect) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.config = _config(oracle_module) + assert vector._get_connection() == "connection" + connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1") + + vector.config = _config( + oracle_module, + is_autonomous=True, + config_dir="/wallet", + wallet_location="/wallet", + wallet_password="pw", + ) + vector._get_connection() + assert connect.call_args.kwargs["config_dir"] == "/wallet" + assert connect.call_args.kwargs["wallet_location"] == "/wallet" + + +def test_create_delegates_collection_and_insert(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.execute.side_effect = [None, RuntimeError("insert failed")] + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + assert cursor.execute.call_count == 2 + assert connection.commit.call_count >= 1 + connection.close.assert_called() + + +def test_text_exists_and_get_by_ids(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.pool = MagicMock() + + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + vector.pool.release.assert_called_once() + assert vector.get_by_ids([]) == [] + + +def test_delete_methods(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + vector.delete_by_ids([]) + vector._get_connection.assert_not_called() + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql) + assert any("JSON_VALUE(meta" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_with_threshold_and_filter(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)]) + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=0, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "fetch first 4 rows only" in sql + assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql + + +def _fake_nltk_module(*, missing_data=False): + nltk = types.ModuleType("nltk") + nltk_corpus = types.ModuleType("nltk.corpus") + + class _Data: + @staticmethod + def find(_path): + if missing_data: + raise LookupError("missing") + return True + + nltk.data = _Data() + nltk.word_tokenize = lambda text: text.split() + nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"]) + return nltk, nltk_corpus + + +def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")])) + zh_docs = vector.search_by_full_text("张三", top_k=2) + assert len(zh_docs) == 1 + zh_params = cursor.execute.call_args.args[1] + assert zh_params["kk"] == "张三" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=False) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])]) + en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"]) + assert len(en_docs) == 1 + en_sql = cursor.execute.call_args.args[0] + en_params = cursor.execute.call_args.args[1] + assert "fetch first 5 rows only" in en_sql + assert "doc_id_0" in en_params + + +def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector._get_connection = MagicMock() + + empty_result = vector.search_by_full_text("") + assert empty_result[0].page_content == "" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=True) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + with pytest.raises(LookupError, match="required NLTK data package"): + vector.search_by_full_text("english query") + + +def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock()) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_not_called() + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(2) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql) + oracle_module.redis_client.set.assert_called_once() + + +def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch): + factory = oracle_module.OracleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False) + + with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000..1aec81b8ac --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,317 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_pgvecto_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSessionContext: + def __init__(self, calls, execute_results=None): + self.calls = calls + self.execute_results = execute_results or [] + self.execute = MagicMock(side_effect=self._execute_side_effect) + self.commit = MagicMock() + + def _execute_side_effect(self, *args, **kwargs): + self.calls.append((args, kwargs)) + if self.execute_results: + return self.execute_results.pop(0) + return MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(calls, execute_results=None): + def _session(_client): + return _FakeSessionContext(calls=calls, execute_results=execute_results) + + return _session + + +@pytest.fixture +def pgvecto_module(monkeypatch): + for name, module in _build_fake_pgvecto_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module + import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + + return importlib.reload(module), importlib.reload(collection_module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "postgres", + } + values.update(overrides) + return module.PgvectoRSConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config PGVECTO_RS_HOST is required"), + ("port", 0, "config PGVECTO_RS_PORT is required"), + ("user", "", "config PGVECTO_RS_USER is required"), + ("password", "", "config PGVECTO_RS_PASSWORD is required"), + ("database", "", "config PGVECTO_RS_DATABASE is required"), + ], +) +def test_pgvecto_config_validation(pgvecto_module, field, value, message): + module, _ = pgvecto_module + values = _config(module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + module.PgvectoRSConfig.model_validate(values) + + +def test_collection_base_has_expected_annotations(pgvecto_module): + _, collection_module = pgvecto_module + annotations = collection_module.CollectionORM.__annotations__ + assert {"id", "text", "meta", "vector"} <= set(annotations) + + +def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == module.VectorType.PGVECTO_RS + module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres") + assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls) + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(module.redis_client, "set", MagicMock()) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection(3) + assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None)) + vector.create_collection(3) + assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls) + module.redis_client.set.assert_called() + + +def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + runtime_calls = [] + execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] + + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + + class _InsertBuilder: + def __init__(self, table): + self.table = table + + def values(self, **kwargs): + return ("insert", kwargs) + + monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table)) + monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["uuid-1", "uuid-2"] + assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0]) + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("row-id-1",)]), + MagicMock(), + ], + ), + ) + vector.delete_by_ids(["doc-1"]) + assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + vector.delete() + assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) + + +def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + runtime_calls = [] + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("id-1",)]), + SimpleNamespace(fetchall=lambda: []), + ], + ), + ) + assert vector.text_exists("doc-1") is True + assert vector.text_exists("doc-1") is False + + class _DistanceExpr: + def label(self, _name): + return self + + class _VectorColumn: + def op(self, _operator, return_type=None): + def _call(_query_vector): + return _DistanceExpr() + + return _call + + class _MetaFilter: + def in_(self, values): + return ("in", values) + + class _MetaColumn: + def __getitem__(self, _item): + return _MetaFilter() + + class _Stmt: + def __init__(self): + self.where_called = False + + def limit(self, _value): + return self + + def order_by(self, _value): + return self + + def where(self, _value): + self.where_called = True + return self + + stmt = _Stmt() + monkeypatch.setattr(module, "select", lambda *_args: stmt) + + vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn()) + rows = [ + (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), + (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), + ] + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert stmt.where_called is True + assert vector.search_by_full_text("hello") == [] + + +def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + factory = module.PGVectoRSFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432) + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres") + + embeddings = MagicMock() + embeddings.embed_query.return_value = [0.1, 0.2, 0.3] + + with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py index 4998a9858f..7505262eb7 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -1,16 +1,19 @@ -import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module from core.rag.datasource.vdb.pgvector.pgvector import ( PGVector, PGVectorConfig, ) +from core.rag.models.document import Document -class TestPGVector(unittest.TestCase): - def setUp(self): +class TestPGVector: + def setup_method(self, method): self.config = PGVectorConfig( host="localhost", port=5432, @@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override): PGVectorConfig(**config) -if __name__ == "__main__": - unittest.main() +def test_create_delegates_collection_creation_and_insert(): + vector = PGVector.__new__(PGVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["doc-a"]) + docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["doc-a"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid") + execute_values = MagicMock() + monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values) + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + execute_values.assert_called_once() + + +def test_text_get_and_delete_methods(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + class _UndefinedTableError(Exception): + pass + + monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError) + cursor.execute.side_effect = _UndefinedTableError("missing") + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + vector.delete_by_ids(["doc-1"]) + + +def test_search_by_vector_supports_filter_and_threshold(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "meta->>'document_id' in ('d-1')" in sql + + +def test_search_by_full_text_branches_for_bigm_and_standard(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + vector.pg_bigm = False + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.7) + standard_sql = cursor.execute.call_args.args[0] + assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql + + cursor.execute.reset_mock() + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)]) + vector.pg_bigm = True + vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"]) + assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0] + assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0] + + +def test_pgvector_factory_initializes_expected_collection_name(monkeypatch): + factory = pgvector_module.PGVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False) + + with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py new file mode 100644 index 0000000000..bd8df520ba --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py @@ -0,0 +1,269 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def vastbase_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VastbaseVectorConfig( + host="localhost", + port=5432, + user="dify", + password="secret", + database="dify", + min_connection=1, + max_connection=5, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config VASTBASE_HOST is required"), + ("port", 0, "config VASTBASE_PORT is required"), + ("user", "", "config VASTBASE_USER is required"), + ("password", "", "config VASTBASE_PASSWORD is required"), + ("database", "", "config VASTBASE_DATABASE is required"), + ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"), + ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"), + ], +) +def test_vastbase_config_validation(vastbase_module, field, value, message): + values = _config(vastbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + vastbase_module.VastbaseVectorConfig.model_validate(values) + + +def test_vastbase_config_rejects_invalid_connection_window(vastbase_module): + with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"): + vastbase_module.VastbaseVectorConfig.model_validate( + { + "host": "localhost", + "port": 5432, + "user": "dify", + "password": "secret", + "database": "dify", + "min_connection": 6, + "max_connection": 5, + } + ) + + +def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + + vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module)) + assert vector.get_type() == "vastbase" + assert vector.table_name == "embedding_collection_1" + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_and_add_texts(vastbase_module, monkeypatch): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + vector._create_collection = MagicMock() + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid") + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert ids == ["doc-a", "generated-uuid"] + vastbase_module.psycopg2.extras.execute_values.assert_called_once() + + vector.add_texts = MagicMock(return_value=["doc-a"]) + result = vector.create(docs, [[0.1], [0.2], [0.3]]) + vector._create_collection.assert_called_once_with(1) + assert result == ["doc-a"] + + +def test_text_get_delete_and_metadata_methods(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_ids([]) + vector.delete_by_ids(["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql) + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_and_full_text(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.8), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock()) + + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + cursor.execute.assert_not_called() + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(17000) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql) + + cursor.execute.reset_mock() + vector._create_collection(3) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql) + vastbase_module.redis_client.set.assert_called() + + +def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch): + factory = vastbase_module.VastbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5) + + with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py new file mode 100644 index 0000000000..0408506563 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py @@ -0,0 +1,328 @@ +import importlib +import os +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_qdrant_modules(): + qdrant_client = types.ModuleType("qdrant_client") + qdrant_http = types.ModuleType("qdrant_client.http") + qdrant_http_models = types.ModuleType("qdrant_client.http.models") + qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions") + qdrant_local_pkg = types.ModuleType("qdrant_client.local") + qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local") + + class UnexpectedResponseError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + class FilterSelector: + def __init__(self, filter): + self.filter = filter + + class HnswConfigDiff: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class TextIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class PointStruct: + def __init__(self, **kwargs): + self.id = kwargs["id"] + self.vector = kwargs["vector"] + self.payload = kwargs["payload"] + + class Filter: + def __init__(self, must=None): + self.must = must or [] + + class FieldCondition: + def __init__(self, key, match): + self.key = key + self.match = match + + class MatchValue: + def __init__(self, value): + self.value = value + + class MatchAny: + def __init__(self, any): + self.any = any + + class MatchText: + def __init__(self, text): + self.text = text + + class _Distance(UserDict): + def __getitem__(self, key): + return key + + class QdrantClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[])) + self.create_collection = MagicMock() + self.create_payload_index = MagicMock() + self.upsert = MagicMock() + self.delete = MagicMock() + self.delete_collection = MagicMock() + self.retrieve = MagicMock(return_value=[]) + self.search = MagicMock(return_value=[]) + self.scroll = MagicMock(return_value=([], None)) + + class QdrantLocal(QdrantClient): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._load = MagicMock() + + qdrant_client.QdrantClient = QdrantClient + qdrant_http_models.FilterSelector = FilterSelector + qdrant_http_models.HnswConfigDiff = HnswConfigDiff + qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD") + qdrant_http_models.TextIndexParams = TextIndexParams + qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT") + qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL") + qdrant_http_models.VectorParams = VectorParams + qdrant_http_models.Distance = _Distance() + qdrant_http_models.PointStruct = PointStruct + qdrant_http_models.Filter = Filter + qdrant_http_models.FieldCondition = FieldCondition + qdrant_http_models.MatchValue = MatchValue + qdrant_http_models.MatchAny = MatchAny + qdrant_http_models.MatchText = MatchText + qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError + + qdrant_http.models = qdrant_http_models + qdrant_local_mod.QdrantLocal = QdrantLocal + qdrant_local_pkg.qdrant_local = qdrant_local_mod + + return { + "qdrant_client": qdrant_client, + "qdrant_client.http": qdrant_http, + "qdrant_client.http.models": qdrant_http_models, + "qdrant_client.http.exceptions": qdrant_http_exceptions, + "qdrant_client.local": qdrant_local_pkg, + "qdrant_client.local.qdrant_local": qdrant_local_mod, + } + + +@pytest.fixture +def qdrant_module(monkeypatch): + for name, module in _build_fake_qdrant_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.qdrant.qdrant_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "endpoint": "http://localhost:6333", + "api_key": "api-key", + "timeout": 20, + "root_path": "/tmp", + "grpc_port": 6334, + "prefer_grpc": False, + "replication_factor": 1, + "write_consistency_factor": 1, + } + values.update(overrides) + return module.QdrantConfig.model_validate(values) + + +def test_qdrant_config_to_params(qdrant_module): + url_params = _config(qdrant_module).to_qdrant_params().model_dump() + assert url_params["url"] == "http://localhost:6333" + assert url_params["verify"] is False + + path_config = _config(qdrant_module, endpoint="path:storage") + assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage") + + with pytest.raises(ValueError, match="Root path is not set"): + _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params() + + +def test_init_and_basic_behaviour(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.get_type() == qdrant_module.VectorType.QDRANT + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1" + + docs = [Document(page_content="a", metadata={"doc_id": "a"})] + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with("collection_1", 1) + vector.add_texts.assert_called_once() + + +def test_create_collection_and_add_texts(qdrant_module, monkeypatch): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.get_collections.return_value = SimpleNamespace(collections=[]) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_called_once() + assert vector._client.create_payload_index.call_count == 4 + qdrant_module.redis_client.set.assert_called_once() + + # add_texts and generated batches + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.upsert.call_count == 1 + + payloads = qdrant_module.QdrantVector._build_payloads( + ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + assert payloads[0]["group_id"] == "g1" + with pytest.raises(ValueError, match="At least one of the texts is None"): + qdrant_module.QdrantVector._build_payloads( + [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + + +def test_delete_and_exists_paths(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete() + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete() + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = None + + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")]) + assert vector.text_exists("id-1") is False + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]) + vector._client.retrieve.return_value = [{"id": "id-1"}] + assert vector.text_exists("id-1") is True + + +def test_search_and_helper_methods(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.search_by_vector([0.1], score_threshold=1.0) == [] + + vector._client.search.return_value = [ + SimpleNamespace(payload=None, score=0.9, vector=[0.1]), + SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]), + ] + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + # full text search: keyword split, dedup and top_k limit + scroll_results = [ + ( + [ + SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]), + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ( + [ + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ] + vector._client.scroll.side_effect = scroll_results + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert vector.search_by_full_text(" ", top_k=2) == [] + + local_client = qdrant_module.QdrantLocal() + vector._client = local_client + vector._reload_if_needed() + local_client._load.assert_called_once() + + doc = vector._document_from_scored_point( + SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]), + "page_content", + "metadata", + ) + assert doc.page_content == "doc" + + +def test_qdrant_factory_paths(qdrant_module, monkeypatch): + factory = qdrant_module.QdrantVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + collection_binding_id=None, + index_struct_dict=None, + index_struct=None, + ) + monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root"))) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1) + + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset.index_struct is not None + + # collection binding lookup path + dataset.collection_binding_id = "binding-1" + dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}} + monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + qdrant_module.db.session.scalars = MagicMock( + return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION")) + ) + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION" + + qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None)) + with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py new file mode 100644 index 0000000000..ca8cd5e514 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -0,0 +1,303 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_relyt_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSession: + def __init__(self, execute_result=None): + self.execute_result = execute_result or MagicMock(fetchall=lambda: []) + self.execute = MagicMock(return_value=self.execute_result) + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +@pytest.fixture +def relyt_module(monkeypatch): + for name, module in _build_fake_relyt_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.relyt.relyt_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "relyt", + } + values.update(overrides) + return module.RelytConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config RELYT_HOST is required"), + ("port", 0, "config RELYT_PORT is required"), + ("user", "", "config RELYT_USER is required"), + ("password", "", "config RELYT_PASSWORD is required"), + ("database", "", "config RELYT_DATABASE is required"), + ], +) +def test_relyt_config_validation(relyt_module, field, value, message): + values = _config(relyt_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + relyt_module.RelytConfig.model_validate(values) + + +def test_init_get_type_and_create_delegate(relyt_module, monkeypatch): + engine = MagicMock() + monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine)) + vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1") + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == relyt_module.VectorType.RELYT + assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt" + assert vector.embedding_dimension == 2 + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock()) + + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + session.execute.assert_not_called() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] + assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) + assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql) + assert any("CREATE INDEX" in sql for sql in executed_sql) + relyt_module.redis_client.set.assert_called_once() + + +def test_add_texts_and_metadata_queries(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector._group_id = "group-1" + vector.client = MagicMock() + + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + + monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "d-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert conn.execute.call_count >= 1 + first_insert_values = conn.execute.call_args.args[0].compile().params + assert "group_id" in str(first_insert_values) + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"] + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +# 1. delete_by_uuids: success and connect error +def test_delete_by_uuids_success_and_connect_error(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + with pytest.raises(ValueError, match="No ids provided"): + vector.delete_by_uuids(None) + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + assert vector.delete_by_uuids(["id-1"]) is True + vector.client.connect.side_effect = RuntimeError("boom") + assert vector.delete_by_uuids(["id-1"]) is False + + +# 2. delete_by_metadata_field calls delete_by_uuids +def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_uuids.assert_called_once_with(["id-1"]) + + +# 3. delete_by_ids translates to uuids +def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_ids(["doc-1", "doc-2"]) + vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"]) + + +# 4. text_exists True +def test_text_exists_true(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is True + + +# 5. text_exists False +def test_text_exists_false(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is False + + +# 6. similarity_search_with_score_by_vector returns Documents and scores +def test_similarity_search_with_score_by_vector(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + result_rows = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8), + ] + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = result_rows + vector.client.connect.return_value = conn + similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]}) + assert len(similarities) == 2 + assert similarities[0][0].page_content == "doc-a" + + +# 7. search_by_vector filters by score and ids +def test_search_by_vector_filters_by_score_and_ids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.similarity_search_with_score_by_vector = MagicMock( + return_value=[ + (Document(page_content="a", metadata={"doc_id": "1"}), 0.1), + (Document(page_content="b", metadata={}), 0.9), + ] + ) + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert vector.search_by_full_text("query") == [] + + +# 8. delete commits session +def test_delete_commits_session(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete() + session.commit.assert_called_once() + + +def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): + factory = relyt_module.RelytVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432) + monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt") + + with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py new file mode 100644 index 0000000000..e3b6676d9b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py @@ -0,0 +1,316 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_tablestore_module(): + tablestore = types.ModuleType("tablestore") + + class _BatchGetRowRequest: + def __init__(self): + self.items = [] + + def add(self, item): + self.items.append(item) + + class _TableInBatchGetRowItem: + def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver): + self.table_name = table_name + self.rows_to_get = rows_to_get + self.columns_to_get = columns_to_get + + class _Row: + def __init__(self, primary_key, attribute_columns=None): + self.primary_key = primary_key + self.attribute_columns = attribute_columns or [] + + class _Client: + def __init__(self, *_args): + self.list_table = MagicMock(return_value=[]) + self.create_table = MagicMock() + self.list_search_index = MagicMock(return_value=[]) + self.create_search_index = MagicMock() + self.delete_search_index = MagicMock() + self.delete_table = MagicMock() + self.put_row = MagicMock() + self.delete_row = MagicMock() + self.get_row = MagicMock(return_value=(None, None, None)) + self.batch_get_row = MagicMock() + self.search = MagicMock() + + tablestore.OTSClient = _Client + tablestore.BatchGetRowRequest = _BatchGetRowRequest + tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem + tablestore.Row = _Row + tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema) + tablestore.TableOptions = lambda: ("table_options",) + tablestore.CapacityUnit = lambda read, write: ("capacity", read, write) + tablestore.ReservedThroughput = lambda cap: ("reserved", cap) + tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs) + tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs) + tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas) + tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs) + tablestore.TermQuery = lambda key, value: ("term_query", key, value) + tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs) + tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.TermsQuery = lambda key, values: ("terms_query", key, values) + tablestore.Sort = lambda **kwargs: ("sort", kwargs) + tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs) + tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs) + + tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD") + tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD") + tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32") + tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE") + tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX") + tablestore.SortOrder = SimpleNamespace(DESC="DESC") + return tablestore + + +@pytest.fixture +def tablestore_module(monkeypatch): + fake_module = _build_fake_tablestore_module() + monkeypatch.setitem(sys.modules, "tablestore", fake_module) + + import core.rag.datasource.vdb.tablestore.tablestore_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "access_key_id": "ak", + "access_key_secret": "sk", + "instance_name": "instance", + "endpoint": "endpoint", + "normalize_full_text_bm25_score": False, + } + values.update(overrides) + return module.TableStoreConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("access_key_id", "", "config ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"), + ("instance_name", "", "config INSTANCE_NAME is required"), + ("endpoint", "", "config ENDPOINT is required"), + ], +) +def test_tablestore_config_validation(tablestore_module, field, value, message): + values = _config(tablestore_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + tablestore_module.TableStoreConfig.model_validate(values) + + +def test_init_and_basic_delegation(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + assert vector.get_type() == tablestore_module.VectorType.TABLESTORE + assert vector._table_name == "collection_1" + assert vector._index_name == "collection_1_idx" + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]]) + + vector.create_collection([[0.1, 0.2]]) + assert vector._create_collection.call_count == 2 + + +def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + # get_by_ids + ok_item = SimpleNamespace( + is_ok=True, + row=SimpleNamespace( + attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)] + ), + ) + fail_item = SimpleNamespace(is_ok=False, row=None) + batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item]) + vector._tablestore_client.batch_get_row.return_value = batch_resp + docs = vector.get_by_ids(["id-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + + # text_exists + vector._tablestore_client.get_row.return_value = (None, object(), None) + assert vector.text_exists("id-1") is True + vector._tablestore_client.get_row.return_value = (None, None, None) + assert vector.text_exists("id-1") is False + + # delete wrappers + vector._delete_row = MagicMock() + vector.delete_by_ids([]) + vector._delete_row.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._delete_row.call_count == 2 + + vector._search_by_metadata = MagicMock(return_value=["id-a"]) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"] + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-a"]) + + vector._search_by_vector = MagicMock(return_value=["vec-doc"]) + vector._search_by_full_text = MagicMock(return_value=["fts-doc"]) + assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"] + assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"] + + vector._delete_table_if_exist = MagicMock() + vector.delete() + vector._delete_table_if_exist.assert_called_once() + + +def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table_if_not_exist = MagicMock() + vector._create_search_index_if_not_exist = MagicMock() + vector._create_collection(3) + vector._create_table_if_not_exist.assert_not_called() + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(3) + vector._create_table_if_not_exist.assert_called_once() + vector._create_search_index_if_not_exist.assert_called_once_with(3) + tablestore_module.redis_client.set.assert_called_once() + + vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module)) + vector._tablestore_client.list_table.return_value = ["collection_2"] + assert vector._create_table_if_not_exist() is None + vector._tablestore_client.list_table.return_value = [] + vector._create_table_if_not_exist() + vector._tablestore_client.create_table.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")] + assert vector._create_search_index_if_not_exist(3) is None + vector._tablestore_client.list_search_index.return_value = [] + vector._create_search_index_if_not_exist(3) + vector._tablestore_client.create_search_index.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")] + vector._delete_table_if_exist() + assert vector._tablestore_client.delete_search_index.call_count == 2 + vector._tablestore_client.delete_table.assert_called_once_with("collection_2") + + vector._delete_search_index() + vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx") + + +def test_write_row_and_search_helpers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + vector._write_row( + "id-1", + { + "page_content": "hello", + "vector": [0.1, 0.2], + "metadata": {"doc_id": "d-1", "document_id": "doc-1"}, + }, + ) + put_row_call = vector._tablestore_client.put_row.call_args + assert put_row_call.args[0] == "collection_1" + attrs = put_row_call.args[1].attribute_columns + assert any(item[0] == "metadata_tags" for item in attrs) + + vector._delete_row("id-1") + vector._tablestore_client.delete_row.assert_called_once() + + # metadata search pagination + first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next") + second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"") + vector._tablestore_client.search.side_effect = [first_page, second_page] + ids = vector._search_by_metadata("document_id", "doc-1") + assert ids == ["row-1", "row-2"] + vector._tablestore_client.search.side_effect = None + + # vector search + hit1 = SimpleNamespace( + score=0.9, + row=( + None, + [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))], + ), + ) + hit2 = SimpleNamespace( + score=0.2, + row=( + None, + [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))], + ), + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2]) + docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0) + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0 + + # full text search with and without normalized score filter + vector._normalize_full_text_bm25_score = True + hit3 = SimpleNamespace( + score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))]) + ) + hit4 = SimpleNamespace( + score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))]) + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4]) + docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2) + assert len(docs) == 1 + assert "score" in docs[0].metadata + + vector._normalize_full_text_bm25_score = False + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3]) + docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0) + assert len(docs) == 1 + assert "score" not in docs[0].metadata + + +def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch): + factory = tablestore_module.TableStoreVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True) + + with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py new file mode 100644 index 0000000000..d8f35a6019 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py @@ -0,0 +1,309 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_tencent_modules(): + tcvdb_text = types.ModuleType("tcvdb_text") + tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder") + tcvectordb = types.ModuleType("tcvectordb") + tcvectordb_model = types.ModuleType("tcvectordb.model") + tcvectordb_document = types.ModuleType("tcvectordb.model.document") + tcvectordb_index = types.ModuleType("tcvectordb.model.index") + tcvectordb_enum = types.ModuleType("tcvectordb.model.enum") + + class _BM25Encoder: + def encode_texts(self, text): + return {"encoded_text": text} + + def encode_queries(self, query): + return {"encoded_query": query} + + @classmethod + def default(cls, _lang): + return cls() + + class VectorDBError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class RPCVectorDBClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_database_if_not_exists = MagicMock() + self.exists_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[])) + self.create_collection = MagicMock() + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.search = MagicMock(return_value=[]) + self.hybrid_search = MagicMock(return_value=[]) + self.drop_collection = MagicMock() + + class _Document: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class _HNSWSearchParams: + def __init__(self, ef): + self.ef = ef + + class _AnnSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _KeywordSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _WeightedRerank: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Filter: + @staticmethod + def in_(field, values): + return ("in", field, values) + + def __init__(self, condition): + self.condition = condition + + _Filter.In = staticmethod(_Filter.in_) + + class _HNSWParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _FilterIndex: + def __init__(self, *args): + self.args = args + + class _VectorIndex: + def __init__(self, *args): + self.args = args + + class _SparseIndex: + def __init__(self, **kwargs): + self.kwargs = kwargs + + tcvectordb_enum.IndexType = SimpleNamespace( + __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"}, + PRIMARY_KEY="PRIMARY_KEY", + FILTER="FILTER", + SPARSE_INVERTED="SPARSE", + ) + tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP") + tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector") + + tcvectordb_document.Document = _Document + tcvectordb_document.HNSWSearchParams = _HNSWSearchParams + tcvectordb_document.AnnSearch = _AnnSearch + tcvectordb_document.Filter = _Filter + tcvectordb_document.KeywordSearch = _KeywordSearch + tcvectordb_document.WeightedRerank = _WeightedRerank + + tcvectordb_index.HNSWParams = _HNSWParams + tcvectordb_index.FilterIndex = _FilterIndex + tcvectordb_index.VectorIndex = _VectorIndex + tcvectordb_index.SparseIndex = _SparseIndex + + tcvdb_text_encoder.BM25Encoder = _BM25Encoder + + tcvectordb_model.document = tcvectordb_document + tcvectordb_model.enum = tcvectordb_enum + tcvectordb_model.index = tcvectordb_index + + tcvectordb.RPCVectorDBClient = RPCVectorDBClient + tcvectordb.VectorDBException = VectorDBError + + return { + "tcvdb_text": tcvdb_text, + "tcvdb_text.encoder": tcvdb_text_encoder, + "tcvectordb": tcvectordb, + "tcvectordb.model": tcvectordb_model, + "tcvectordb.model.document": tcvectordb_document, + "tcvectordb.model.index": tcvectordb_index, + "tcvectordb.model.enum": tcvectordb_enum, + } + + +@pytest.fixture +def tencent_module(monkeypatch): + for name, module in _build_fake_tencent_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.tencent.tencent_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "url": "http://vdb.local", + "api_key": "api-key", + "timeout": 30, + "username": "user", + "database": "db", + "index_type": "HNSW", + "metric_type": "IP", + "shard": 1, + "replicas": 2, + "max_upsert_batch_size": 2, + "enable_hybrid_search": False, + } + values.update(overrides) + return module.TencentConfig.model_validate(values) + + +def test_config_and_init_paths(tencent_module): + config = _config(tencent_module) + assert config.to_tencent_params()["url"] == "http://vdb.local" + + vector = tencent_module.TencentVector("collection_1", config) + assert vector.get_type() == tencent_module.VectorType.TENCENT + assert vector._client.kwargs["key"] == "api-key" + + vector._client.exists_collection.return_value = True + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)] + ) + vector._client_config.enable_hybrid_search = True + vector._load_collection() + assert vector._enable_hybrid_search is True + assert vector._dimension == 768 + + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=512)] + ) + vector._load_collection() + assert vector._enable_hybrid_search is False + + +def test_create_collection_branches(tencent_module, monkeypatch): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.exists_collection.return_value = True + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + vector._client.exists_collection.return_value = False + vector._client_config.index_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_collection(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_collection(3) + + vector._client_config.metric_type = "IP" + vector._client.create_collection.side_effect = [ + tencent_module.VectorDBException("fieldType:json unsupported"), + None, + ] + vector._enable_hybrid_search = True + vector._create_collection(3) + assert vector._client.create_collection.call_count == 2 + tencent_module.redis_client.set.assert_called_once() + vector._client.create_collection.side_effect = None + + +def test_create_add_delete_and_search_behaviour(tencent_module): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True)) + vector._create_collection = MagicMock() + docs = [ + Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}), + Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}), + Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + vector.create(docs, embeddings) + vector._create_collection.assert_called_once_with(1) + + vector._client.upsert.reset_mock() + vector.add_texts(docs, embeddings) + assert vector._client.upsert.call_count == 2 + first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"] + assert "sparse_vector" in first_docs[0].__dict__ + + vector._client.query.return_value = [{"id": "a"}] + assert vector.text_exists("a") is True + vector._client.query.return_value = [] + assert vector.text_exists("a") is False + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["a", "b", "c"]) + assert vector._client.delete.call_count == 2 + vector.delete_by_metadata_field("document_id", "doc-a") + assert vector._client.delete.call_count >= 3 + + vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]] + vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(vec_docs) == 1 + assert vec_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._enable_hybrid_search = False + assert vector.search_by_full_text("query") == [] + vector._enable_hybrid_search = True + vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]] + fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(fts_docs) == 1 + + # _get_search_res handles old string metadata format + compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5) + assert len(compat_docs) == 1 + assert compat_docs[0].metadata["score"] == pytest.approx(0.8) + + vector._has_collection = MagicMock(return_value=True) + vector.delete() + vector._client.drop_collection.assert_called_once() + + +def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch): + factory = tencent_module.TencentVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True) + + with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py new file mode 100644 index 0000000000..369cda39bf --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + + +class _DummyVector(BaseVector): + def __init__(self, collection_name: str, existing_ids: set[str] | None = None): + super().__init__(collection_name) + self._existing_ids = existing_ids or set() + + def get_type(self) -> str: + return "dummy" + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete_by_metadata_field(self, key: str, value: str): + return None + + def search_by_vector(self, query_vector: list[float], **kwargs): + return [] + + def search_by_full_text(self, query: str, **kwargs): + return [] + + def delete(self): + return None + + +@pytest.mark.parametrize( + ("base_method", "args"), + [ + (BaseVector.get_type, ()), + (BaseVector.create, ([], [])), + (BaseVector.add_texts, ([], [])), + (BaseVector.text_exists, ("doc-1",)), + (BaseVector.delete_by_ids, ([],)), + (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.search_by_vector, ([0.1],)), + (BaseVector.search_by_full_text, ("query",)), + (BaseVector.delete, ()), + ], +) +def test_base_vector_default_methods_raise_not_implemented(base_method, args): + vector = _DummyVector("collection_1") + + with pytest.raises(NotImplementedError): + base_method(vector, *args) + + +def test_filter_duplicate_texts_removes_existing_docs(): + vector = _DummyVector("collection_1", existing_ids={"dup"}) + docs = [ + SimpleNamespace(page_content="keep-no-meta", metadata=None), + Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}), + Document(page_content="remove-dup", metadata={"doc_id": "dup"}), + Document(page_content="keep-unique", metadata={"doc_id": "unique"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + + assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"] + + +def test_get_uuids_and_collection_name_property(): + vector = _DummyVector("collection_1") + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + Document(page_content="c", metadata={"document_id": "d-1"}), + Document(page_content="d", metadata={"doc_id": "id-2"}), + ] + + assert vector._get_uuids(docs) == ["id-1", "id-2"] + assert vector.collection_name == "collection_1" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py new file mode 100644 index 0000000000..54ad6d330b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -0,0 +1,436 @@ +import base64 +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str): + fake_module = types.ModuleType(module_path) + fake_cls = type(class_name, (), {}) + setattr(fake_module, class_name, fake_cls) + monkeypatch.setitem(sys.modules, module_path, fake_module) + return fake_cls + + +@pytest.fixture +def vector_factory_module(): + import importlib + + import core.rag.datasource.vdb.vector_factory as module + + return importlib.reload(module) + + +def test_gen_index_struct_dict(vector_factory_module): + result = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict( + vector_factory_module.VectorType.WEAVIATE, + "collection_1", + ) + + assert result == { + "type": vector_factory_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "collection_1"}, + } + + +@pytest.mark.parametrize( + ("vector_type", "module_path", "class_name"), + [ + ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ( + "ALIBABACLOUD_MYSQL", + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "AlibabaCloudMySQLVectorFactory", + ), + ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ( + "ELASTICSEARCH", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "ElasticSearchVectorFactory", + ), + ( + "ELASTICSEARCH_JA", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "ElasticSearchJaVectorFactory", + ), + ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ( + "OPENSEARCH", + "core.rag.datasource.vdb.opensearch.opensearch_vector", + "OpenSearchVectorFactory", + ), + ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ( + "TIDB_ON_QDRANT", + "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "TidbOnQdrantVectorFactory", + ), + ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ( + "HUAWEI_CLOUD", + "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "HuaweiCloudVectorFactory", + ), + ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ], +) +def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): + expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name) + + result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type)) + + assert result_cls is expected_cls + + +def test_get_vector_factory_unsupported(vector_factory_module): + with pytest.raises(ValueError, match="not supported"): + vector_factory_module.Vector.get_vector_factory("unknown") + + +def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): + dataset = SimpleNamespace(id="dataset-1") + + with ( + patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), + patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), + ): + default_vector = vector_factory_module.Vector(dataset) + custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) + + assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + assert custom_vector._attributes == ["doc_id"] + assert default_vector._embeddings == "embeddings" + assert default_vector._vector_processor == "processor" + + +def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): + calls = {"vector_type": None, "init_args": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + calls["init_args"] = (dataset, attributes, embeddings) + return "vector-processor" + + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1" + ) + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH + assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings") + + +def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch): + class _Expr: + def __eq__(self, _other): + return "expr" + + calls = {"vector_type": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + return "vector-processor" + + monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))), + ) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True) + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT + + +def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch): + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = [] + vector._embeddings = "embeddings" + + with pytest.raises(ValueError, match="Vector store must be specified"): + vector._init_vector() + + +def test_create_batches_texts_and_skips_empty_input(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)] + vector._embeddings.embed_documents.side_effect = [ + [[0.1] for _ in range(1000)], + [[0.2]], + ] + + vector.create(texts=docs, trace_id="trace-1") + + assert vector._embeddings.embed_documents.call_count == 2 + assert vector._vector_processor.create.call_count == 2 + assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1" + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=None) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): + class _Field: + def in_(self, value): + return value + + def __eq__(self, value): + return value + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]] + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace( + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")]) + ) + ), + ) + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc")) + + docs = [ + Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}), + Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}), + ] + + vector.create_multimodal(file_documents=docs, request_id="r-1") + + file_base64 = base64.b64encode(b"abc").decode() + vector._embeddings.embed_multimodal_documents.assert_called_once_with( + [{"content": file_base64, "content_type": "image", "file_id": "f-1"}] + ) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], + embeddings=[[0.1, 0.2]], + request_id="r-1", + ) + + vector._embeddings.embed_multimodal_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create_multimodal(file_documents=None) + vector._embeddings.embed_multimodal_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_with_optional_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock() + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + vector._filter_duplicate_texts.return_value = [docs[0]] + vector._embeddings.embed_documents.return_value = [[0.1]] + + vector.add_texts(docs, duplicate_check=True, flag=True) + + vector._filter_duplicate_texts.assert_called_once_with(docs) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True + ) + + vector._filter_duplicate_texts.reset_mock() + vector._vector_processor.create.reset_mock() + vector._embeddings.embed_documents.return_value = [[0.2], [0.3]] + + vector.add_texts(docs, duplicate_check=False) + + vector._filter_duplicate_texts.assert_not_called() + vector._vector_processor.create.assert_called_once() + + +def test_vector_delegation_methods(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_query.return_value = [0.1, 0.2] + vector._vector_processor = MagicMock() + vector._vector_processor.text_exists.return_value = True + vector._vector_processor.search_by_vector.return_value = ["vector-doc"] + vector._vector_processor.search_by_full_text.return_value = ["text-doc"] + + assert vector.text_exists("doc-1") is True + vector.delete_by_ids(["doc-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"] + assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"] + + vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"]) + vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1") + + +def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): + class _Field: + def __eq__(self, value): + return value + + upload_query = MagicMock() + upload_query.where.return_value = upload_query + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr( + vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query)) + ) + + upload_query.first.return_value = None + assert vector.search_by_file("file-1") == [] + + upload_query.first.return_value = SimpleNamespace(key="blob-key") + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) + vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] + vector._vector_processor.search_by_vector.return_value = ["hit"] + + result = vector.search_by_file("file-2", top_k=2) + + assert result == ["hit"] + payload = vector._embeddings.embed_multimodal_query.call_args.args[0] + assert payload["content_type"] == vector_factory_module.DocType.IMAGE + assert payload["file_id"] == "file-2" + + +def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch): + delete_mock = MagicMock() + redis_delete = MagicMock() + monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1") + + vector.delete() + + delete_mock.assert_called_once() + redis_delete.assert_called_once_with("vector_indexing_collection_1") + + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="") + redis_delete.reset_mock() + vector.delete() + redis_delete.assert_not_called() + + +def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch): + model_manager = MagicMock() + model_manager.get_model_instance.return_value = "model-instance" + + for_tenant_mock = MagicMock(return_value=model_manager) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + result = vector._get_embeddings() + + assert result == "cached-embedding" + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=vector_factory_module.ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + +def test_filter_duplicate_texts_and_getattr(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup") + + docs = [ + SimpleNamespace(page_content="no-meta", metadata=None), + Document(page_content="empty-doc-id", metadata={"doc_id": ""}), + Document(page_content="duplicate", metadata={"doc_id": "dup"}), + Document(page_content="unique", metadata={"doc_id": "ok"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"] + + class _Processor: + def ping(self): + return "pong" + + vector._vector_processor = _Processor() + assert vector.ping() == "pong" + + with pytest.raises(AttributeError): + _ = vector.unknown_method + + vector._vector_processor = None + with pytest.raises(AttributeError, match="vector_processor"): + _ = vector.another_missing diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000..951a920f3b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,443 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +@pytest.fixture +def tidb_module(): + import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + + return importlib.reload(module) + + +def _config(tidb_module): + return tidb_module.TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="secret", + database="dify", + program_name="dify-app", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config TIDB_VECTOR_HOST is required"), + ("port", 0, "config TIDB_VECTOR_PORT is required"), + ("user", "", "config TIDB_VECTOR_USER is required"), + ("database", "", "config TIDB_VECTOR_DATABASE is required"), + ("program_name", "", "config APPLICATION_NAME is required"), + ], +) +def test_tidb_config_validation(tidb_module, field, value, message): + values = _config(tidb_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + tidb_module.TiDBVectorConfig.model_validate(values) + + +def test_init_get_type_and_distance_func(tidb_module, monkeypatch): + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine")) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2") + + assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR + assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify") + assert vector._dimension == 1536 + assert vector._get_distance_func() == "VEC_L2_DISTANCE" + + vector._distance_func = "cosine" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + vector._distance_func = "other" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + +def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch): + fake_tidb_vector = types.ModuleType("tidb_vector") + fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy") + + class _VectorType: + def __init__(self, dim): + self.dim = dim + + fake_tidb_sqlalchemy.VectorType = _VectorType + + monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector) + monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy) + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock())) + monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr( + tidb_module, + "Table", + lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns), + ) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module)) + table = vector._table(3) + + assert table.name == "collection_1" + column_names = [column.args[0] for column in table.columns] + assert tidb_module.Field.PRIMARY_KEY in column_names + assert tidb_module.Field.VECTOR in column_names + assert tidb_module.Field.TEXT_KEY in column_names + + +def test_create_calls_collection_and_add_texts(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + assert vector._dimension == 2 + + +def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + + tidb_module.Session = MagicMock() + + vector._create_collection(3) + + tidb_module.Session.assert_not_called() + tidb_module.redis_client.set.assert_not_called() + + +def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "l2" + + vector._create_collection(3) + + session.begin.assert_called_once() + sql = str(session.execute.call_args.args[0]) + assert "VECTOR(3)" in sql + assert "VEC_L2_DISTANCE" in sql + session.commit.assert_called_once() + tidb_module.redis_client.set.assert_called_once() + + +def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch): + class _InsertStmt: + def __init__(self, table): + self.table = table + + def values(self, rows): + return {"table": self.table, "rows": rows} + + monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table)) + + conn = MagicMock() + transaction = MagicMock() + transaction.__enter__.return_value = None + transaction.__exit__.return_value = None + conn.begin.return_value = transaction + + connection_ctx = MagicMock() + connection_ctx.__enter__.return_value = conn + connection_ctx.__exit__.return_value = None + + engine = MagicMock() + engine.connect.return_value = connection_ctx + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._engine = engine + vector._table = MagicMock(return_value="table") + + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)] + embeddings = [[float(i)] for i in range(501)] + + ids = vector.add_texts(docs, embeddings) + + assert ids[0] == "id-0" + assert len(ids) == 501 + assert conn.execute.call_count == 2 + + +@pytest.fixture +def tidb_vector_with_session(tidb_module, monkeypatch): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + return vector, session, tidb_module + + +# 1. search_by_full_text returns empty +def test_search_by_full_text_returns_empty(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + assert vector.search_by_full_text("query") == [] + + +# 2. text_exists returns True when ids found +def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + + +# 3. text_exists returns False when no ids +def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=None) + assert vector.text_exists("doc-1") is False + + +# 4. delete_by_ids delegates to _delete_by_ids when ids found +def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"]) + + +# 5. delete_by_ids skips when no ids found +def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-c"]) + vector._delete_by_ids.assert_not_called() + + +# 6. get_ids_by_metadata_field returns ids and returns None +def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + # Returns ids + session.execute.return_value.fetchall.return_value = [("id-1",)] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"] + # Returns None + session.execute.return_value.fetchall.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None + + +# 1. _delete_by_ids raises on None +def test__delete_by_ids_raises_on_none(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + with pytest.raises(ValueError, match="No ids provided"): + vector._delete_by_ids(None) + + +# 2. _delete_by_ids returns True and calls execute +def test__delete_by_ids_returns_true_and_calls_execute(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-1"]) is True + conn.execute.assert_called_once() + + +# 3. _delete_by_ids returns False on RuntimeError +def test__delete_by_ids_returns_false_on_runtime_error(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + conn.execute.side_effect = RuntimeError("delete failed") + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-2"]) is False + + +# 4. delete_by_metadata_field calls _delete_by_ids when ids found +def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-3"]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-3") + vector._delete_by_ids.assert_called_once_with(["id-3"]) + + +# 5. delete_by_metadata_field does nothing when no ids +def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-4") + vector._delete_by_ids.assert_not_called() + + +# Test search_by_vector filters and scores +def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4), + ] + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params["distance"] == pytest.approx(0.5) + assert params["top_k"] == 2 + session.commit.assert_not_called() + + +# Test delete drops table +def test_delete_drops_table(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = None + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector.delete() + drop_sql = str(session.execute.call_args.args[0]) + assert "DROP TABLE IF EXISTS collection_1" in drop_sql + session.commit.assert_called_once() + + +def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): + factory = tidb_module.TiDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000) + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") + + with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000..ac8a63a44b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,186 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_upstash_module(): + upstash_module = types.ModuleType("upstash_vector") + + class Vector: + def __init__(self, id, vector, metadata, data): + self.id = id + self.vector = vector + self.metadata = metadata + self.data = data + + class Index: + def __init__(self, url, token): + self.url = url + self.token = token + self.info = MagicMock(return_value=SimpleNamespace(dimension=8)) + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.reset = MagicMock() + + upstash_module.Vector = Vector + upstash_module.Index = Index + return upstash_module + + +@pytest.fixture +def upstash_module(monkeypatch): + # Remove patched modules if present + for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + if modname in sys.modules: + monkeypatch.delitem(sys.modules, modname, raising=False) + monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) + module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + return module + + +def _config(module): + return module.UpstashVectorConfig(url="https://upstash.example", token="token-123") + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("url", "", "Upstash URL is required"), + ("token", "", "Upstash Token is required"), + ], +) +def test_upstash_config_validation(upstash_module, field, value, message): + values = _config(upstash_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + upstash_module.UpstashVectorConfig.model_validate(values) + + +def test_init_get_type_and_dimension(upstash_module, monkeypatch): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + assert vector.get_type() == upstash_module.VectorType.UPSTASH + assert vector._table_name == "collection_1" + assert vector._get_index_dimension() == 8 + + vector.index.info.return_value = SimpleNamespace(dimension=None) + assert vector._get_index_dimension() == 1536 + + vector.index.info.return_value = None + assert vector._get_index_dimension() == 1536 + + monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid") + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.add_texts(docs, [[0.1, 0.2]]) + + vector.index.upsert.assert_called_once() + upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"] + assert upsert_vectors[0].id == "generated-uuid" + + +def test_create_text_exists_and_delete_by_ids(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + vector.get_ids_by_metadata_field.return_value = [] + assert vector.text_exists("doc-1") is False + + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]]) + vector._delete_by_ids = MagicMock() + vector.delete_by_ids(["doc-1", "doc-2", "doc-3"]) + vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"]) + + +def test_delete_helpers_and_search(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + vector._delete_by_ids([]) + vector.index.delete.assert_not_called() + vector._delete_by_ids(["a", "b"]) + vector.index.delete.assert_called_once_with(ids=["a", "b"]) + + vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")] + ids = vector.get_ids_by_metadata_field("doc_id", "doc-1") + assert ids == ["x-1", "x-2"] + query_kwargs = vector.index.query.call_args.kwargs + assert query_kwargs["top_k"] == 1000 + assert query_kwargs["filter"] == "doc_id = 'doc-1'" + + vector._delete_by_ids = MagicMock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + vector._delete_by_ids.assert_called_once_with(["x-1"]) + + vector._delete_by_ids.reset_mock() + vector.get_ids_by_metadata_field.return_value = [] + vector.delete_by_metadata_field("doc_id", "doc-2") + vector._delete_by_ids.assert_not_called() + + +def test_search_by_vector_filter_threshold_and_delete(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.index.query.return_value = [ + SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9), + SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3), + SimpleNamespace(metadata=None, data="text-3", score=0.99), + SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=3, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + search_kwargs = vector.index.query.call_args.kwargs + assert search_kwargs["top_k"] == 3 + assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')" + + assert vector.search_by_full_text("query") == [] + + vector.delete() + vector.index.reset.assert_called_once() + + +def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch): + factory = upstash_module.UpstashVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123") + + with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py new file mode 100644 index 0000000000..9da92af2d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py @@ -0,0 +1,310 @@ +import importlib +import json +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_vikingdb_modules(): + volcengine = types.ModuleType("volcengine") + volcengine.__path__ = [] + viking_db = types.ModuleType("volcengine.viking_db") + + class Data(UserDict): + def __init__(self, payload): + super().__init__(payload) + self.fields = payload + + class DistanceType: + L2 = "L2" + + class IndexType: + HNSW = "HNSW" + + class QuantType: + Float = "Float" + + class FieldType: + String = "string" + Text = "text" + Vector = "vector" + + class Field: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Collection: + def __init__(self): + self.upsert_data = MagicMock() + self.fetch_data = MagicMock(return_value=None) + self.delete_data = MagicMock() + + class _Index: + def __init__(self): + self.search = MagicMock(return_value=[]) + self.search_by_vector = MagicMock(return_value=[]) + + class VikingDBService: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_collection = MagicMock() + self.create_index = MagicMock() + self.drop_index = MagicMock() + self.drop_collection = MagicMock() + self._collection = _Collection() + self._index = _Index() + self.get_collection = MagicMock(return_value=self._collection) + self.get_index = MagicMock(return_value=self._index) + + viking_db.Data = Data + viking_db.DistanceType = DistanceType + viking_db.Field = Field + viking_db.FieldType = FieldType + viking_db.IndexType = IndexType + viking_db.QuantType = QuantType + viking_db.VectorIndexParams = VectorIndexParams + viking_db.VikingDBService = VikingDBService + + return {"volcengine": volcengine, "volcengine.viking_db": viking_db} + + +@pytest.fixture +def vikingdb_module(monkeypatch): + for name, module in _build_fake_vikingdb_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VikingDBConfig( + access_key="ak", + secret_key="sk", + host="host", + region="region", + scheme="https", + connection_timeout=10, + socket_timeout=20, + ) + + +def test_init_get_type_and_has_checks(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB + assert vector._index_name == "collection_1_idx" + + assert vector._has_collection() is True + assert vector._has_index() is True + + vector._client.get_collection.side_effect = RuntimeError("missing") + assert vector._has_collection() is False + vector._client.get_collection.side_effect = None + + vector._client.get_index.side_effect = RuntimeError("missing") + assert vector._has_index() is False + + +def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock()) + + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + vector._client.create_index.assert_not_called() + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None)) + vector._has_collection = MagicMock(return_value=False) + vector._has_index = MagicMock(return_value=False) + vector._create_collection(4) + + vector._client.create_collection.assert_called_once() + vector._client.create_index.assert_called_once() + vikingdb_module.redis_client.set.assert_called_once() + + +def test_create_and_add_texts(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module)) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}), + ] + vector.add_texts(docs, [[0.1], [0.2]]) + + vector._client.get_collection.assert_called() + upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0] + assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a" + assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2" + + +def test_text_exists_and_delete_operations(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"}) + assert vector.text_exists("id-1") is True + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace( + fields={"message": "data does not exist"} + ) + assert vector.text_exists("id-1") is False + + vector._client.get_collection.return_value.fetch_data.return_value = None + assert vector.text_exists("id-1") is False + + vector.delete_by_ids(["id-1"]) + vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-2"]) + + +def test_get_ids_and_search_helpers(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_index.return_value.search.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "x") == [] + + vector._client.get_index.return_value.search.return_value = [ + SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}), + SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}), + SimpleNamespace(id="c", fields={}), + ] + assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"] + + empty_docs = vector._get_search_res([], score_threshold=0.1) + assert empty_docs == [] + + results = [ + SimpleNamespace( + id="a", + score=0.3, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}), + }, + ), + SimpleNamespace( + id="b", + score=0.9, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}), + }, + ), + ] + + docs = vector._get_search_res(results, score_threshold=0.2) + assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"] + + vector._client.get_index.return_value.search_by_vector.return_value = results + filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"]) + assert len(filtered_docs) == 1 + assert filtered_docs[0].page_content == "doc-b" + assert vector.search_by_full_text("query") == [] + + +def test_delete_drops_index_and_collection_when_present(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._has_index = MagicMock(return_value=True) + vector._has_collection = MagicMock(return_value=True) + + vector.delete() + + vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx") + vector._client.drop_collection.assert_called_once_with("collection_1") + + vector._client.drop_index.reset_mock() + vector._client.drop_collection.reset_mock() + vector._has_index.return_value = False + vector._has_collection.return_value = False + vector.delete() + + vector._client.drop_index.assert_not_called() + vector._client.drop_collection.assert_not_called() + + +def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch): + factory = vikingdb_module.VikingDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls: + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10) + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20) + + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +@pytest.mark.parametrize( + ("field", "message"), + [ + ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"), + ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"), + ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"), + ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"), + ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"), + ], +) +def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message): + factory = vikingdb_module.VikingDBVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None + ) + + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, field, None) + + with pytest.raises(ValueError, match=message): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py index 3bd656ba84..69d1833001 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in: - Full-text search result metadata (search_by_full_text) """ +import datetime +import json import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase): def tearDown(self): weaviate_vector_module._weaviate_client = None + def test_config_requires_endpoint(self): + with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): + WeaviateConfig(endpoint="") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" @@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase): ) return wv, mock_client + def test_shutdown_client_logs_debug_when_close_fails(self): + mock_client = MagicMock() + mock_client.close.side_effect = RuntimeError("close failed") + weaviate_vector_module._weaviate_client = mock_client + + with patch.object(weaviate_vector_module.logger, "debug") as mock_debug: + weaviate_vector_module._shutdown_weaviate_client() + + assert weaviate_vector_module._weaviate_client is None + mock_client.close.assert_called_once() + mock_debug.assert_called_once() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.return_value = True + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.side_effect = [False, True] + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + config = WeaviateConfig( + endpoint="https://weaviate.example.com", + grpc_endpoint="grpc.example.com:6000", + api_key="test-key", + batch_size=50, + ) + + client = wv._init_client(config) + + assert client is mock_client + assert mock_connect.call_args.kwargs == { + "http_host": "weaviate.example.com", + "http_port": 443, + "http_secure": True, + "grpc_host": "grpc.example.com", + "grpc_port": 6000, + "grpc_secure": False, + "auth_credentials": "auth-token", + "skip_init_checks": True, + } + mock_api_key.assert_called_once_with("test-key") + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_raises_when_database_not_ready(self, mock_connect): + mock_client = MagicMock() + mock_client.is_ready.return_value = False + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + + with pytest.raises(ConnectionError, match="Vector database is not ready"): + wv._init_client(self.config) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" @@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase): assert wv._collection_name == self.collection_name assert "doc_type" in wv._attributes + def test_get_type_and_to_index_struct(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + + assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE + assert wv.to_index_struct() == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": self.collection_name}, + } + + def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self): + dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1") + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv.get_collection_name(dataset) == "ExistingCollection_Node" + + def test_get_collection_name_generates_name_from_dataset_id(self): + dataset = SimpleNamespace(index_struct_dict=None, id="ds-2") + wv = WeaviateVector.__new__(WeaviateVector) + + with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"): + assert wv.get_collection_name(dataset) == "Generated_Node" + + def test_create_calls_collection_setup_then_add_texts(self): + doc = Document(page_content="hello", metadata={}) + wv = WeaviateVector.__new__(WeaviateVector) + wv._create_collection = MagicMock() + wv.add_texts = MagicMock() + + wv.create([doc], [[0.1, 0.2]]) + + wv._create_collection.assert_called_once() + wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._ensure_properties = MagicMock() + + wv._create_collection() + + wv._client.collections.exists.assert_not_called() + wv._ensure_properties.assert_not_called() + mock_redis.set.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_logs_and_reraises_errors(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock(return_value=False) + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = RuntimeError("create failed") + + with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: + with pytest.raises(RuntimeError, match="create failed"): + wv._create_collection() + + mock_exception.assert_called_once() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" @@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [SimpleNamespace(name="text")] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" @@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + mock_col.config.add_property.side_effect = RuntimeError("cannot add") + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: + wv._ensure_properties() + + assert mock_warning.call_count == 4 + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = { + "text": "fallback distance result", + "document_id": "doc-1", + "doc_id": "segment-1", + } + mock_obj.metadata = None + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector( + query_vector=[0.2] * 3, + document_ids_filter=["doc-1"], + top_k=2, + score_threshold=-1, + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.0 + assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_vector(query_vector=[0.1] * 3) == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" @@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"} + mock_obj.vector = [0.3, 0.4] + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].vector == [0.3, 0.4] + assert mock_col.query.bm25.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_full_text(query="missing") == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" @@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + doc = Document(page_content="text", metadata={"created_at": created_at}) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with ( + patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), + patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + ): + ids = wv.add_texts(documents=[doc], embeddings=[[]]) + + assert ids == ["fallback-uuid"] + call_kwargs = mock_batch.add_object.call_args + assert call_kwargs.kwargs["uuid"] == "fallback-uuid" + assert call_kwargs.kwargs["vector"] is None + assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat() + + def test_is_uuid_handles_invalid_values(self): + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True + assert wv._is_uuid("not-a-uuid") is False + + def test_delete_by_metadata_field_returns_when_collection_is_missing(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = False + + wv.delete_by_metadata_field("doc_id", "segment-1") + + wv._client.collections.use.assert_not_called() + + def test_delete_by_metadata_field_deletes_matching_objects(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + + wv.delete_by_metadata_field("doc_id", "segment-1") + + mock_col.data.delete_many.assert_called_once() + + def test_delete_removes_collection_when_present(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + + wv.delete() + + wv._client.collections.delete.assert_called_once_with(self.collection_name) + + def test_text_exists_handles_missing_and_present_documents(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()]) + + assert wv.text_exists("segment-1") is False + assert wv.text_exists("segment-1") is True + + def test_delete_by_ids_handles_missing_collections_and_404s(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None] + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + wv.delete_by_ids(["ignored"]) + wv.delete_by_ids(["missing-id", "ok-id"]) + + assert mock_col.data.delete_by_id.call_count == 2 + + def test_delete_by_ids_reraises_non_404_errors(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): + wv.delete_by_ids(["bad-id"]) + + def test_json_serializable_converts_datetime(self): + wv = WeaviateVector.__new__(WeaviateVector) + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + + assert wv._json_serializable(created_at) == created_at.isoformat() + assert wv._json_serializable("plain") == "plain" + class TestVectorDefaultAttributes(unittest.TestCase): """Tests for Vector class default attributes list.""" @@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase): assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" +class TestWeaviateVectorFactory(unittest.TestCase): + def test_init_vector_uses_existing_dataset_index_struct(self): + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}}, + index_struct=None, + ) + attributes = ["doc_id"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + config = mock_vector.call_args.kwargs["config"] + assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node" + assert mock_vector.call_args.kwargs["attributes"] == attributes + assert config.endpoint == "http://localhost:8080" + assert config.grpc_endpoint == "localhost:50051" + assert config.api_key == "api-key" + assert config.batch_size == 88 + assert dataset.index_struct is None + + def test_init_vector_generates_collection_and_updates_index_struct(self): + dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + attributes = ["doc_id", "doc_type"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100), + patch.object( + weaviate_vector_module.Dataset, + "gen_collection_name_by_id", + return_value="GeneratedCollection_Node", + ), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node" + assert json.loads(dataset.index_struct) == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "GeneratedCollection_Node"}, + } + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0..3ba0628fe2 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,7 +163,7 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d..bfa78fe565 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -15,8 +15,8 @@ import pytest from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -64,7 +65,7 @@ class TestCacheEmbeddingMultimodalDocuments: def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): """Test embedding a single multimodal document when cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: @@ -316,13 +317,14 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance def test_embed_multimodal_query_cache_miss(self, mock_model_instance): """Test embedding multimodal query when Redis cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) document = {"file_id": "file123"} vector = np.random.randn(1536) @@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -532,24 +535,13 @@ class TestCacheEmbeddingQueryErrors: class TestCacheEmbeddingInitialization: """Test suite for CacheEmbedding initialization.""" - def test_initialization_with_user(self): - """Test CacheEmbedding initialization with user parameter.""" - model_instance = Mock() - model_instance.model = "test-model" - model_instance.provider = "test-provider" - - cache_embedding = CacheEmbedding(model_instance, user="test-user") - - assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" - - def test_initialization_without_user(self): - """Test CacheEmbedding initialization without user parameter.""" + def test_initialization_sets_model_instance(self): + """Test CacheEmbedding initialization stores the provided model instance.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 6e71f0c61f..392f0b458b 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,9 +53,9 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, @@ -134,7 +134,7 @@ class TestCacheEmbeddingDocuments: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Python is a programming language"] # Mock database query to return no cached embedding (cache miss) @@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments: # Verify model was invoked with correct parameters mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=texts, - user="test-user", input_type=EmbeddingInputType.DOCUMENT, ) @@ -612,7 +611,7 @@ class TestCacheEmbeddingQuery: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is Python?" # Create embedding result @@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery: # Verify model was invoked with QUERY input type mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user="test-user", input_type=EmbeddingInputType.QUERY, ) @@ -1568,25 +1566,16 @@ class TestEmbeddingEdgeCases: norm = np.linalg.norm(emb) assert abs(norm - 1.0) < 0.01 - def test_embed_query_with_user_context(self, mock_model_instance): - """Test query embedding with user context parameter. + def test_embed_query_uses_bound_model_instance(self, mock_model_instance): + """Test query embedding using the provided model instance. Verifies: - - User parameter is passed correctly to model - - User context is used for tracking/logging - - Embedding generation works with user context - - Context: - -------- - The user parameter is important for: - 1. Usage tracking per user - 2. Rate limiting per user - 3. Audit logging - 4. Personalization (in some models) + - Embedding generation works with the injected model instance + - Query input type is preserved + - No extra binding step is required at call time """ # Arrange - user_id = "user-12345" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is machine learning?" # Create embedding @@ -1620,24 +1609,20 @@ class TestEmbeddingEdgeCases: assert isinstance(result, list) assert len(result) == 1536 - # Verify user parameter was passed to model mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user=user_id, input_type=EmbeddingInputType.QUERY, ) - def test_embed_documents_with_user_context(self, mock_model_instance): - """Test document embedding with user context parameter. + def test_embed_documents_uses_bound_model_instance(self, mock_model_instance): + """Test document embedding using the provided model instance. Verifies: - - User parameter is passed correctly for document embeddings - - Batch processing maintains user context - - User tracking works across batches + - Batch processing uses the injected model instance + - Document input type is preserved """ # Arrange - user_id = "user-67890" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Document 1", "Document 2"] # Create embeddings @@ -1673,10 +1658,8 @@ class TestEmbeddingEdgeCases: # Assert assert len(result) == 2 - # Verify user parameter was passed mock_model_instance.invoke_text_embedding.assert_called_once() call_args = mock_model_instance.invoke_text_embedding.call_args - assert call_args.kwargs["user"] == user_id assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 2add12fd09..db49221583 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -164,6 +164,13 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="No page found"): app.check_crawl_status("job-1") + def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") payload = {"status": "processing", "total": 5, "completed": 1, "data": []} @@ -203,6 +210,77 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="Error saving crawl data"): app.check_crawl_status("job-err") + def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture): + """When status is completed and next is present, follow pagination to collect all pages.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + page2 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2", + "data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}], + } + page3 = { + "status": "completed", + "total": 3, + "completed": 3, + "data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)]) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"] + + def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture): + """An error while fetching a paginated page raises an exception; no partial data is returned.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 2, + "completed": 2, + "next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})]) + + with pytest.raises(Exception, match="fetch next crawl page"): + app.check_crawl_status("job-99") + + def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture): + """Pagination stops once pages_processed reaches total, even if next is present.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + # total=1: only the first page should be processed; next must not be followed + page1 = { + "status": "completed", + "total": 1, + "completed": 1, + "next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mock_get = mocker.patch("httpx.get", return_value=_response(200, page1)) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-cap") + + assert len(result["data"]) == 1 + mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap + def test_extract_common_fields_and_status_formatter(self): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index e6cc582398..c861871f02 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -4,11 +4,12 @@ from unittest.mock import Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -267,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", @@ -399,7 +400,9 @@ class TestParagraphIndexProcessor: model_instance.invoke_llm.return_value = self._llm_result("text summary") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -410,7 +413,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, usage = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -433,7 +436,9 @@ class TestParagraphIndexProcessor: image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -448,7 +453,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, _ = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -469,7 +474,9 @@ class TestParagraphIndexProcessor: image_file = SimpleNamespace() with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -486,7 +493,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): ParagraphIndexProcessor.generate_summary( "tenant-1", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 5c78cae7c1..b1ed735ee7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 99323eeec9..98c47bec8f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -298,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..059876d410 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,9 +61,9 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -482,7 +482,8 @@ class TestIndexingRunnerTransform: # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [ @@ -509,7 +510,7 @@ class TestIndexingRunnerTransform: assert len(result) == 2 assert result[0].page_content == "Chunk 1" assert result[1].page_content == "Chunk 2" - runner.model_manager.get_model_instance.assert_called_once_with( + model_manager.get_model_instance.assert_called_once_with( tenant_id=sample_dataset.tenant_id, provider=sample_dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -521,7 +522,8 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + model_manager = mock_dependencies["model_manager"].return_value + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -539,14 +541,15 @@ class TestIndexingRunnerTransform: # Assert assert len(result) == 1 - runner.model_manager.get_model_instance.assert_not_called() + model_manager.get_model_instance.assert_not_called() def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs): """Test transformation with custom segmentation rules.""" # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})] @@ -586,7 +589,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -605,7 +608,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -645,7 +648,8 @@ class TestIndexingRunnerLoad: runner = IndexingRunner() mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -664,7 +668,7 @@ class TestIndexingRunnerLoad: runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents) # Assert - runner.model_manager.get_model_instance.assert_called_once() + model_manager.get_model_instance.assert_called_once() # Verify executor was used for parallel processing assert mock_executor_instance.submit.called @@ -674,7 +678,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +705,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -714,7 +718,8 @@ class TestIndexingRunnerLoad: mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -754,7 +759,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): @@ -795,7 +800,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +954,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1..415597f336 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -28,7 +28,7 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -352,12 +352,14 @@ class TestRerankModelRunner: # Assert: Empty result is returned assert len(result) == 0 - def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents): - """Test that user parameter is passed to model invocation. + def test_run_uses_bound_model_instance( + self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager + ): + """Test that rerank uses the bound model instance directly. Verifies: - - User ID is correctly forwarded to the model - - Model receives all expected parameters + - The injected model instance is used for invocation + - No late rebinding occurs through ModelManager.get_model_instance """ # Arrange: Mock rerank result mock_rerank_result = RerankResult( @@ -368,16 +370,18 @@ class TestRerankModelRunner: ) mock_model_instance.invoke_rerank.return_value = mock_rerank_result - # Act: Run reranking with user parameter + # Act: Run reranking result = rerank_runner.run( query="test", documents=sample_documents, - user="user123", ) - # Assert: User parameter is passed to model + # Assert: The injected model instance is invoked directly. + assert len(result) == 1 + mock_model_manager.return_value.get_model_instance.assert_not_called() call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs - assert call_kwargs["user"] == "user123" + assert call_kwargs["query"] == "test" + assert "user" not in call_kwargs class _ForwardingBaseRerankRunner(BaseRerankRunner): @@ -387,7 +391,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: return super().run( @@ -395,7 +398,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents=documents, score_threshold=score_threshold, top_n=top_n, - user=user, query_type=query_type, ) @@ -424,7 +426,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +443,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -539,8 +541,10 @@ class TestRerankModelRunnerMultimodal: ) mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + session = MagicMock() + session.query.return_value = query_chain with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), ): result, unique_documents = rerank_runner.fetch_multimodal_rerank( @@ -548,7 +552,6 @@ class TestRerankModelRunnerMultimodal: documents=[text_doc], score_threshold=0.2, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -557,7 +560,7 @@ class TestRerankModelRunnerMultimodal: invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE assert invoke_kwargs["docs"][0]["content"] == "text-content" - assert invoke_kwargs["user"] == "user-1" + assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): query_chain = Mock() @@ -595,7 +598,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1148,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1260,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1527,7 +1530,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1601,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1676,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1718,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1827,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c..a7e62e7b0a 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -35,9 +35,10 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset +from models.enums import CreatorUserRole # ==================== Helper Functions ==================== @@ -3747,6 +3748,24 @@ class TestDatasetRetrievalAdditionalHelpers: mock_session.add_all.assert_called() mock_session.commit.assert_called() + def test_on_query_normalizes_workflow_end_user_role(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query="python", + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="end-user", + user_id="u1", + ) + + mock_session.add_all.assert_called_once() + added_queries = mock_session.add_all.call_args.args[0] + + assert len(added_queries) == 1 + assert added_queries[0].created_by_role == CreatorUserRole.END_USER + mock_session.commit.assert_called_once() + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: usage = LLMUsage.empty_usage() chunk_1 = SimpleNamespace( @@ -3836,7 +3855,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -4222,11 +4241,12 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4279,9 +4299,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4294,8 +4318,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": "secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4312,12 +4386,17 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4402,7 +4481,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), @@ -4413,7 +4492,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", @@ -4800,8 +4886,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index cfa9094e12..43c521dcfd 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,7 +1,7 @@ from unittest.mock import Mock from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index e429563739..c56528cf55 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -3,8 +3,9 @@ from unittest.mock import Mock, patch from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.model_entities import ModelType class TestReactMultiDatasetRouter: @@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -162,7 +165,11 @@ class TestReactMultiDatasetRouter: model_instance = Mock() model_instance.invoke_llm.return_value = iter([chunk]) - with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + with ( + patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager, + patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance text, returned_usage = router._invoke_llm( completion_param={"temperature": 0.1}, model_instance=model_instance, @@ -174,6 +181,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7eecfa297..2735ec512f 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType +from graphon.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 2a83a4e802..05b4f3a053 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from core.repositories.factory import OrderConfig +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed0307..48327c3913 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d12664..8be1ac318c 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,16 +14,18 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, +) +from graphon.nodes.human_input.entities import ( + FormDefinition, UserAction, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -272,7 +274,7 @@ def _make_form_definition() -> str: inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ).model_dump_json() @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 4116e8b4a5..1297a95df1 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -11,6 +11,8 @@ from unittest.mock import MagicMock import pytest from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, HumanInputFormRecord, HumanInputFormRepositoryImpl, HumanInputFormSubmissionRepository, @@ -19,18 +21,16 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType @@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None: assert entity.token == "tok" -def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: +def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: form = _DummyForm( id="f1", workflow_run_id="run", @@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No ) entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] - assert entity.web_app_token == "ctok" + assert entity.submission_token == "ctok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] - assert entity.web_app_token == "wtok" + assert entity.submission_token == "wtok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] - assert entity.web_app_token is None + assert entity.submission_token is None def test_form_entity_submitted_data_parsed() -> None: @@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), subject="s", body="b", @@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc session=MagicMock(), form_id="f", delivery_id="d", - recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), ) assert recipients == ["ok"] @@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m form_id="f", delivery_id="d", recipients_config=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), ) assert recipients == ["ok"] @@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - assert repo.get_form("run", "node") is None + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None form = _DummyForm( id="f1", @@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - entity = repo.get_form("run", "node") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") assert entity is not None assert entity.id == "f1" assert entity.recipients[0].id == "r1" @@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M session = _FakeSession() _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) form_config = HumanInputNodeData( title="Title", @@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M user_actions=[UserAction(id="submit", title="Submit")], ) params = FormCreateParams( - app_id="app", - workflow_execution_id="run", + workflow_execution_id=None, node_id="node", form_config=form_config, rendered_content="

hello

", @@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M display_in_ui=True, resolved_default_values={}, form_kind=HumanInputFormKind.RUNTIME, - console_recipient_required=True, - console_creator_account_id="acc-1", - backstage_recipient_required=True, ) entity = repo.create_form(params) assert entity.id == "form-id" assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) # Console token should take precedence when console recipient is present. - assert entity.web_app_token == "token-console" + assert entity.submission_token == "token-console" assert len(entity.recipients) == 3 diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 232ab07882..6cb3c3c6ac 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -7,7 +7,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from graphon.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index c7af32789b..6af7b02d4c 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -15,6 +15,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from configs import dify_config +from core.repositories.factory import OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, _deterministic_json_dump, @@ -22,13 +23,12 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import ( - NodeType, +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom @@ -67,7 +67,7 @@ def _execution( index=1, predecessor_node_id=None, node_id="node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Title", inputs=inputs, outputs=outputs, @@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "t" db_model.inputs = json.dumps({"trunc": "i"}) db_model.process_data = json.dumps({"trunc": "p"}) @@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest. db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "t" db_model.inputs = json.dumps({"i": 1}) db_model.process_data = json.dumps({"p": 2}) @@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> lambda max_workers: FakeExecutor(), ) - result = repo.get_by_workflow_run("run", order_config=None) + result = repo.get_by_workflow_execution("run", order_config=None) assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 456c3dde12..abdbc72085 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -10,11 +10,11 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index eeab81a178..5af1376a0a 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -17,11 +17,11 @@ from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..36e8e1bbb1 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..a68fce5e7f --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 251d6fd25e..f17927f16b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,6 @@ import json -from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 92e4b58473..afea9144c0 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -6,8 +6,8 @@ from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 90ed1647aa..b19a21d7f4 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -12,9 +12,9 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormOption, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb..7f6a50af99 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,12 +1,26 @@ -from unittest.mock import Mock, PropertyMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from models.provider import LoadBalancingModelConfig, ProviderModelSetting +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel +from models.provider_ids import ModelProviderID + + +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + +def _build_session_context(session: Mock) -> MagicMock: + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return session_cm @pytest.fixture @@ -28,7 +42,7 @@ def mock_provider_entity(): return mock_entity -def test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +83,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +103,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +133,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +151,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +190,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +208,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +218,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), @@ -228,3 +242,345 @@ def test_get_default_model_uses_first_available_active_model(): assert saved_default_model.model_name == "gpt-3.5-turbo" assert saved_default_model.provider_name == "openai" mock_session.commit.assert_called_once() + + +def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture): + mock_session = Mock() + mock_session.scalar.return_value = None + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is None + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_factory_cls.assert_not_called() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture): + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="openai", + model_name="gpt-4", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result is not None + assert result.model == "gpt-4" + assert result.provider.provider == "openai" + + +def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_records = {"openai": [SimpleNamespace(provider_name="openai")]} + provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]} + preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")} + + with ( + patch.object(manager, "_get_all_providers", return_value=provider_records), + patch.object(manager, "_init_trial_provider_records", return_value=provider_records), + patch.object(manager, "_get_all_provider_models", return_value=provider_model_records), + patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_providers.return_value = [] + + result = manager.get_configurations("tenant-id") + + expected_alias = str(ModelProviderID("openai")) + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result.tenant_id == "tenant-id" + assert expected_alias in provider_records + assert expected_alias in provider_model_records + assert expected_alias in preferred_provider_records + + +@pytest.mark.parametrize( + ("provider_name", "expected_provider_names"), + [ + ("openai", ["openai", "langgenius/openai/openai"]), + ("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]), + ("langgenius/gemini/google", ["langgenius/gemini/google", "google"]), + ], +) +def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]): + assert ProviderManager._get_provider_names(provider_name) == expected_provider_names + + +def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + model_type_instance = Mock() + provider_configuration.get_model_type_instance.return_value = model_type_instance + expected_bundle = Mock() + + with ( + patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, + ): + result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM) + mock_bundle.assert_called_once_with( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + assert result is expected_bundle + + +def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == (None, None) + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False) + + +def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-4", provider=Mock(provider="openai")), + Mock(model="gpt-4o", provider=Mock(provider="openai")), + ] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == ("openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_model(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + + +def test_update_default_model_record_updates_existing_record(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")] + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="anthropic", + model_name="claude-3-sonnet", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + assert result is existing_default_model + assert existing_default_model.provider_name == "openai" + assert existing_default_model.model_name == "gpt-3.5-turbo" + mock_session.commit.assert_called_once() + mock_session.add.assert_not_called() + + +def test_update_default_model_record_creates_record_with_origin_model_type(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + mock_session = Mock() + mock_session.scalar.return_value = None + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + mock_session.add.assert_called_once() + created_default_model = mock_session.add.call_args.args[0] + assert result is created_default_model + assert created_default_model.tenant_id == "tenant-id" + assert created_default_model.provider_name == "openai" + assert created_default_model.model_name == "gpt-4" + assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + mock_session.commit.assert_called_once() + + +def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None: + session = Mock() + openai_provider = SimpleNamespace(provider_name="openai") + gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google") + session.scalars.return_value = [openai_provider, gemini_provider] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_providers("tenant-id") + + assert list(result[str(ModelProviderID("openai"))]) == [openai_provider] + assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider] + + +@pytest.mark.parametrize( + "method_name", + [ + "_get_all_provider_models", + "_get_all_provider_model_settings", + "_get_all_provider_model_credentials", + ], +) +def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None: + session = Mock() + openai_primary = SimpleNamespace(provider_name="openai") + openai_secondary = SimpleNamespace(provider_name="openai") + anthropic_record = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = getattr(ProviderManager, method_name)("tenant-id") + + assert list(result["openai"]) == [openai_primary, openai_secondary] + assert list(result["anthropic"]) == [anthropic_record] + + +def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None: + session = Mock() + openai_preference = SimpleNamespace(provider_name="openai") + anthropic_preference = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_preference, anthropic_preference] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_preferred_model_providers("tenant-id") + + assert result == { + "openai": openai_preference, + "anthropic": anthropic_preference, + } + + +def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None: + with ( + patch("core.provider_manager.redis_client.get", return_value=b"False"), + patch("core.provider_manager.FeatureService.get_features") as mock_get_features, + patch("core.provider_manager.Session") as mock_session_cls, + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + assert result == {} + mock_get_features.assert_not_called() + mock_session_cls.assert_not_called() + + +def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: + session = Mock() + openai_config = SimpleNamespace(provider_name="openai") + anthropic_config = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_config, anthropic_config] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.redis_client.get", return_value=None), + patch("core.provider_manager.redis_client.setex") as mock_setex, + patch( + "core.provider_manager.FeatureService.get_features", + return_value=SimpleNamespace(model_load_balancing_enabled=True), + ), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") + assert list(result["openai"]) == [openai_config] + assert list(result["anthropic"]) == [anthropic_config] diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index f123f60a34..1ff81f6120 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -12,7 +12,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from dify_graph.model_runtime.entities.message_entities import UserPromptMessage +from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", + prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5b..9ac280e31a 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,6 +1,8 @@ from __future__ import annotations +import calendar import math +from datetime import date from types import SimpleNamespace import pytest @@ -25,8 +27,8 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from dify_graph.file.enums import FileType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.file.enums import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." + ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text @@ -186,13 +194,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index a5242a78c5..353988d7a6 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse import pytest -from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature +from core.tools.signature import ( + get_signed_file_url_for_plugin, + sign_tool_file, + sign_upload_file, + verify_plugin_file_signature, + verify_tool_file_signature, +) def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: @@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] + + +def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload/for-plugin" + assert query["tenant_id"] == ["tenant-id"] + assert query["user_id"] == ["user-id"] + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is True + ) + + +def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + query = parse_qs(urlparse(url).query) + + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100) + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is False + ) diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd6..b3442636b7 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -14,6 +14,7 @@ import httpx import pytest from core.tools.tool_file_manager import ToolFileManager +from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with _patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e22654..844bc01e29 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 5ceaa08893..ae5638784c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -110,7 +110,7 @@ def test_encrypt_tool_parameters(): assert encrypted["plain"] == "x" -def test_decrypt_tool_parameters_cache_hit_and_miss(): +def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch): manager = _build_manager() with ( @@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache(): mock_delete.assert_called_once() -def test_configuration_manager_decrypt_suppresses_errors(): +def test_configuration_manager_decrypt_suppresses_errors(monkeypatch): manager = _build_manager() with ( patch.object(ToolParameterCache, "get", return_value=None), diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f..6454a5bcd1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b2..a4a563a4a1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -15,8 +15,8 @@ from unittest.mock import Mock, patch import pytest from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index dd79b79718..43f3fbd5c9 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -3,7 +3,7 @@ import pytest from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index dd140cbb27..b147d7fcdb 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -13,7 +13,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cc00f79698..72a73dd936 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -24,7 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FILE_MODEL_IDENTITY +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: @@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): """Transform args into parameters and files payloads.""" tool = _setup_transform_args_tool(monkeypatch) + build_file_from_stored_mapping = MagicMock( + side_effect=[ + SimpleNamespace( + transfer_method=FileTransferMethod.TOOL_FILE, + type=FileType.IMAGE, + reference="tool-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.LOCAL_FILE, + type=FileType.DOCUMENT, + reference="upload-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.REMOTE_URL, + type=FileType.DOCUMENT, + reference=None, + generate_url=lambda: "https://example.com/a.pdf", + ), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.build_file_from_stored_mapping", + build_file_from_stored_mapping, + ) params, files = tool._transform_args( { @@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + assert build_file_from_stored_mapping.call_count == 3 + assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list) def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index bcb1d745e3..ee7a3d9c96 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -26,7 +26,7 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a45..72052c8c05 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -5,11 +5,12 @@ import pytest from pydantic import BaseModel from core.helper import encrypter -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import VariablePool +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -25,13 +26,13 @@ from dify_graph.variables.segments import ( StringSegment, get_segment_discriminator, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import ( +from graphon.variables.types import SegmentType +from graphon.variables.utils import ( dumps_with_segments, segment_orjson_default, to_selector, ) -from dify_graph.variables.variables import ( +from graphon.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -48,14 +49,28 @@ from dify_graph.variables.variables import ( ) +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return variable_pool + + def test_segment_group_to_text(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index bb234d9bbd..d4e862220a 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,8 +1,8 @@ import pytest -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import StringSegment +from graphon.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 41ce483447..14f9b2991d 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -10,10 +10,10 @@ from typing import Any import pytest -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +22,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dd0fe2e65a..dae5e1ce98 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from dify_graph.variables import ( +from graphon.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +11,7 @@ from dify_graph.variables import ( SegmentType, StringVariable, ) -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3..3ce4bb753b 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 22792eb5b3..ef5500b72f 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock, patch import pytest -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from graphon.variables.variables import StringVariable class StubCoordinator: @@ -23,6 +23,17 @@ class StubCoordinator: class TestGraphRuntimeState: + def test_execution_context_defaults_to_empty_context(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + with state.execution_context: + assert state.execution_context is not None + + state.execution_context = None + + with state.execution_context: + assert state.execution_context is not None + def test_property_getters_and_setters(self): # FIXME(-LAN-): Mock VariablePool if needed variable_pool = VariablePool() @@ -117,7 +128,7 @@ class TestGraphRuntimeState: queue = state.ready_queue - from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue + from graphon.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) @@ -126,7 +137,7 @@ class TestGraphRuntimeState: execution = state.graph_execution - from dify_graph.graph_engine.domain.graph_execution import GraphExecution + from graphon.graph_engine.domain.graph_execution import GraphExecution assert isinstance(execution, GraphExecution) assert execution.workflow_id == "" @@ -141,7 +152,7 @@ class TestGraphRuntimeState: mock_graph = MagicMock() with patch( - "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True + "graphon.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True ) as coordinator_cls: coordinator_instance = coordinator_cls.return_value state.configure(graph=mock_graph) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index 158f7018b5..856ec959b7 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -5,7 +5,7 @@ Tests for PauseReason discriminated union serialization/deserialization. import pytest from pydantic import BaseModel, ValidationError -from dify_graph.entities.pause_reason import ( +from graphon.entities.pause_reason import ( HumanInputRequired, PauseReason, SchedulingPause, diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py index 2d4c7f7b77..e8304b9bcd 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ b/api/tests/unit_tests/core/workflow/entities/test_template.py @@ -1,6 +1,6 @@ """Tests for template module.""" -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment +from graphon.nodes.base.template import Template, TextSegment, VariableSegment class TestTemplate: diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 6100ebede5..7e08751683 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -1,5 +1,5 @@ -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ( +from graphon.runtime import VariablePool +from graphon.variables.segments import ( BooleanSegment, IntegerSegment, NoneSegment, @@ -126,7 +126,7 @@ class TestVariablePoolGetNotModifyVariableDictionary: def test_get_should_not_modify_variable_dictionary(self): pool = VariablePool.empty() pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id + assert len(pool.variable_dictionary) == 0 assert "start" not in pool.variable_dictionary pool = VariablePool.empty() diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py index 216e64db8d..5e697f22f3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -8,8 +8,8 @@ from typing import Any import pytest -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes +from graphon.entities.workflow_node_execution import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes class TestWorkflowNodeExecutionProcessDataTruncation: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py index 24bd9ccbed..b138a7dfdc 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -2,10 +2,10 @@ from unittest.mock import Mock -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph.edge import Edge -from dify_graph.graph.graph import Graph -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from graphon.graph.edge import Edge +from graphon.graph.graph import Graph +from graphon.nodes.base.node import Node def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py index 64c2eee776..f3eaa1d686 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph import Graph -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph import Graph +from graphon.nodes.base.node import Node def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 75de07bd8b..3620a20e56 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -5,11 +5,11 @@ from typing import Any import pytest from core.workflow.node_factory import DifyNodeFactory -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.graph import Graph +from graphon.graph.validation import GraphValidationError +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -63,7 +63,7 @@ def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], ), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index e94ad74eb0..bfd0b48392 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,14 +6,14 @@ from dataclasses import dataclass import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType +from graphon.graph import Graph +from graphon.graph.validation import GraphValidationError +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -96,7 +96,7 @@ def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: invoke_from="service-api", call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) return factory, graph_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 40ed61eb02..960fef7d43 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -68,7 +68,7 @@ print(f"Success rate: {suite_result.success_rate:.1f}%") #### Event Sequence Validation ```python -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, @@ -376,39 +376,39 @@ See `test_mock_example.py` for comprehensive examples including: ```bash # Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py # Run with specific test patterns -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -k "test_echo" # Run with verbose output -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -v ``` ### Mock System Tests ```bash # Run auto-mock system tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_auto_mock_system.py # Run examples -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py +uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_example.py # Run simple validation -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py +uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_simple.py ``` ### All Tests ```bash # Run all graph engine tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ +uv run pytest api/tests/unit_tests/graphon/graph_engine/ # Run with coverage -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine +uv run pytest api/tests/unit_tests/graphon/graph_engine/ --cov=graphon.graph_engine # Run in parallel -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto +uv run pytest api/tests/unit_tests/graphon/graph_engine/ -n auto ``` ## Troubleshooting diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 4dec618e49..795362b158 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,15 +3,15 @@ import json from unittest.mock import MagicMock -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import ( +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.entities.commands import ( AbortCommand, CommandType, GraphEngineCommand, UpdateVariablesCommand, VariableUpdate, ) -from dify_graph.variables import IntegerVariable, StringVariable +from graphon.variables import IntegerVariable, StringVariable class TestRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 6f821ba799..cacbe9ba4e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,18 +2,18 @@ from __future__ import annotations -from dify_graph.entities.base_node_data import RetryConfig -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.runtime import GraphRuntimeState, VariablePool +from graphon.entities.base_node_data import RetryConfig +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine.domain.graph_execution import GraphExecution +from graphon.graph_engine.event_management.event_handlers import EventHandler +from graphon.graph_engine.event_management.event_manager import EventManager +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue +from graphon.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from graphon.graph_events import NodeRunRetryEvent, NodeRunStartedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py index 25494dc647..dc0998caf1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py @@ -4,9 +4,9 @@ from __future__ import annotations import logging -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_engine.event_management.event_manager import EventManager +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent class _FaultyLayer(GraphEngineLayer): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py index 73d59ea4e9..b030496eb1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, create_autospec -from dify_graph.graph import Edge, Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator +from graphon.graph import Edge, Graph +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.graph_traversal.skip_propagator import SkipPropagator class TestSkipPropagator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index fc8133f5e1..2fead1d719 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,13 +7,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, HumanInputFormRepository, ) +from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now @@ -49,7 +49,7 @@ class _InMemoryFormEntity(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return self.token @property @@ -88,24 +88,24 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): self._form_counter = 0 self.created_params: list[FormCreateParams] = [] self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: self.created_params.append(params) self._form_counter += 1 form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + token = f"token-{form_id}" entity = _InMemoryFormEntity( form_id=form_id, rendered=params.rendered_content, token=token, ) self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + self._forms_by_node_id[params.node_id] = entity return entity - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) # Convenience helpers for tests ------------------------------------- diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 9e7b3654b7..b642dc82fe 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -10,7 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes @pytest.fixture @@ -63,7 +63,7 @@ def mock_llm_node(): def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" from core.tools.entities.tool_entities import ToolProviderType - from dify_graph.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from dify_graph.graph_events.node import NodeRunSucceededEvent - from dify_graph.node_events.base import NodeRunResult + from graphon.graph_events.node import NodeRunSucceededEvent + from graphon.node_events.base import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py index db32527849..7ff77c19c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -2,13 +2,13 @@ from __future__ import annotations import pytest -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers.base import ( GraphEngineLayer, GraphEngineLayerNotInitializedError, ) -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_events import GraphEngineEvent from ..test_table_runner import WorkflowRunner diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..80874e768a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,14 +1,27 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.entities.commands import CommandType -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from core.model_manager import ModelInstance +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events.node import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + + +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) def _build_succeeded_event() -> NodeRunSucceededEvent: @@ -25,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -32,8 +50,8 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -41,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -53,8 +71,8 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -62,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,7 +92,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +109,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,8 +131,8 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -141,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -167,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -175,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 478a2b592e..14ce55938d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -16,7 +16,7 @@ import pytest from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index 548c10ce8d..ab3a31f673 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -5,18 +5,18 @@ from __future__ import annotations import queue from unittest import mock -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_events import ( +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.event_management.event_handlers import EventHandler +from graphon.graph_engine.orchestration.dispatcher import Dispatcher +from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from graphon.graph_events import ( GraphNodeEventBase, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult +from graphon.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py index 7af6b26d87..1510c8e595 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index fc0d22f739..5d0b37acc5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -7,7 +7,7 @@ for workflows containing nodes that require third-party services. import pytest -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -201,7 +201,7 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory @@ -308,8 +308,8 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.nodes.template_transform import TemplateTransformNode - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.nodes.template_transform import TemplateTransformNode + from graphon.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py index 30acbdaf3d..cefe3b8ac8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 765c4deba3..01ac2d7a96 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,23 +3,23 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.entities.commands import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.entities.commands import ( AbortCommand, CommandType, PauseCommand, UpdateVariablesCommand, VariableUpdate, ) -from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.variables import IntegerVariable, StringVariable +from graphon.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import IntegerVariable, StringVariable def test_abort_command(): @@ -73,9 +73,8 @@ def test_abort_command(): config=GraphEngineConfig(), ) - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) + # Queue an abort request before starting. + engine.request_abort("Test abort") # Run engine and collect events events = list(engine.run()) @@ -102,7 +101,7 @@ def test_redis_channel_serialization(): mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel + from graphon.graph_engine.command_channels.redis_channel import RedisChannel # Create channel with a specific key channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index 3a9a0b18bc..ba9c502452 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,7 +7,7 @@ This test suite validates the behavior of a workflow that: 3. Handles multiple answer nodes with different outputs """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index 76bf179f33..3851480731 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,10 +6,10 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py index 778dad5952..3264ad1168 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -1,10 +1,10 @@ import queue -from datetime import datetime -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.orchestration.dispatcher import Dispatcher +from graphon.graph_events import NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from libs.datetime_utils import naive_utc_now class StubExecutionCoordinator: @@ -52,7 +52,7 @@ def test_dispatcher_drains_events_when_paused() -> None: id="exec-1", node_id="node-1", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), ) event_queue.put(event) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py index c87dc75b95..ada55f3dc5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -6,7 +6,7 @@ field is missing from the output configuration, ensuring backward compatibility with older workflow definitions. """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 35406997ed..95a94110d2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool +from graphon.graph_engine.command_processing.command_processor import CommandProcessor +from graphon.graph_engine.domain.graph_execution import GraphExecution +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from graphon.graph_engine.worker_management.worker_pool import WorkerPool def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4e13177d2b..51ece26d49 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,11 +10,11 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType -from dify_graph.enums import ErrorStrategy -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.entities.base_node_data import DefaultValue, DefaultValueType +from graphon.enums import ErrorStrategy +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -455,7 +455,7 @@ def test_if_else_workflow_property_diverse_inputs(query_input): # Tests for the Layer system def test_layer_system_basic(): """Test basic layer functionality with DebugLoggingLayer.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer + from graphon.graph_engine.layers import DebugLoggingLayer runner = WorkflowRunner() @@ -495,7 +495,7 @@ def test_layer_system_basic(): def test_layer_chaining(): """Test chaining multiple layers.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer + from graphon.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer # Create a custom test layer class TestLayer(GraphEngineLayer): @@ -549,7 +549,7 @@ def test_layer_chaining(): def test_layer_error_handling(): """Test that layer errors don't crash the engine.""" - from dify_graph.graph_engine.layers import GraphEngineLayer + from graphon.graph_engine.layers import GraphEngineLayer # Create a layer that throws errors class FaultyLayer(GraphEngineLayer): @@ -591,7 +591,7 @@ def test_layer_error_handling(): def test_event_sequence_validation(): """Test the new event sequence validation feature.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() @@ -678,7 +678,7 @@ def test_event_sequence_validation(): def test_event_sequence_validation_with_table_tests(): """Test event sequence validation with table-driven tests.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 255784b77d..348ceb6788 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -6,13 +6,13 @@ import json from collections import deque from unittest.mock import MagicMock -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph_engine.domain import GraphExecution -from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator -from dify_graph.graph_engine.response_coordinator.path import Path -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.graph_events import NodeRunStreamChunkEvent -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from graphon.graph_engine.domain import GraphExecution +from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator +from graphon.graph_engine.response_coordinator.path import Path +from graphon.graph_engine.response_coordinator.session import ResponseSession +from graphon.graph_events import NodeRunStreamChunkEvent +from graphon.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index d54f0be190..a6417822d2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,26 +1,26 @@ import time from collections.abc import Mapping -from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeState -from dify_graph.graph import Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.llm.entities import ( +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import NodeState +from graphon.graph import Graph +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.ready_queue import InMemoryReadyQueue +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -29,7 +29,7 @@ from .test_mock_nodes import MockLLMNode def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 538f53c603..ca9a929591 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,8 +4,11 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -14,25 +17,23 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ def _build_branching_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -125,6 +126,7 @@ def _build_branching_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -246,7 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -302,7 +304,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 36bba6deb6..c50aaafe2c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,8 +3,11 @@ import time from unittest import mock from unittest.mock import MagicMock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -13,25 +16,23 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ def _build_llm_human_llm_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," ), user_inputs={}, @@ -121,6 +122,7 @@ def _build_llm_human_llm_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -191,7 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -260,7 +262,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index 8da179c15e..246df45d5f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,33 +1,33 @@ import time from unittest import mock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -44,7 +44,7 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr ) variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 733fd53bc8..821da46b76 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -5,7 +5,7 @@ This test validates the behavior of a loop containing an answer node inside the loop that may produce output errors. """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, @@ -14,6 +14,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -50,6 +51,7 @@ def test_loop_contains_answer(): NodeRunLoopStartedEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 1 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, @@ -60,6 +62,7 @@ def test_loop_contains_answer(): NodeRunLoopNextEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 2 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py index 6ff2722f78..4a60c7769c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, @@ -7,6 +7,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -44,12 +45,16 @@ def test_loop_with_tool(): NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, NodeRunLoopNextEvent, # 2024 NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, # LOOP END NodeRunLoopSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 93010eea54..76b2984a4b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -8,9 +8,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +28,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -111,7 +111,7 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, http_request_config=self._http_request_config, http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, + tool_file_manager_factory=self._bound_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) elif node_type in { diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 3e4247f33f..aff479104f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,8 +2,8 @@ Simple test to verify MockNodeFactory works with iteration nodes. """ -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -11,8 +11,8 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( @@ -63,8 +63,8 @@ def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode # Create mock config @@ -128,8 +128,8 @@ def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode # Create mock config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9..971b9b2bbf 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,29 +11,29 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.code import CodeNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode -from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.nodes.question_classifier import QuestionClassifierNode -from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) -from dify_graph.nodes.tool import ToolNode +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.code import CodeNode +from graphon.nodes.document_extractor import DocumentExtractorNode +from graphon.nodes.http_request import HttpRequestNode +from graphon.nodes.llm import LLMNode +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.parameter_extractor import ParameterExtractorNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.question_classifier import QuestionClassifierNode +from graphon.nodes.template_transform import TemplateTransformNode +from graphon.nodes.tool import ToolNode +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -66,20 +66,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses - from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() @@ -596,8 +602,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from dify_graph.nodes.iteration import IterationNode -from dify_graph.nodes.loop import LoopNode +from graphon.nodes.iteration import IterationNode +from graphon.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -611,11 +617,11 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory @@ -656,7 +662,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError + from graphon.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -683,11 +689,11 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index a8398e8f79..15f6f51398 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,9 +6,9 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.code.limits import CodeNodeLimits +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode @@ -40,8 +40,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_default_output(self): """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -60,7 +60,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -103,8 +103,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_custom_output(self): """Test that MockTemplateTransformNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -123,7 +123,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -167,8 +167,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_error_simulation(self): """Test that MockTemplateTransformNode can simulate errors.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -187,7 +187,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -228,9 +228,9 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that MockTemplateTransformNode processes templates with variables.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from dify_graph.variables import StringVariable + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool + from graphon.variables import StringVariable # Create test parameters graph_init_params = GraphInitParams( @@ -249,7 +249,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -298,8 +298,8 @@ class TestMockCodeNode: def test_mock_code_node_default_output(self): """Test that MockCodeNode returns default output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -318,7 +318,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -364,8 +364,8 @@ class TestMockCodeNode: def test_mock_code_node_with_output_schema(self): """Test that MockCodeNode generates outputs based on schema.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -384,7 +384,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -438,8 +438,8 @@ class TestMockCodeNode: def test_mock_code_node_custom_output(self): """Test that MockCodeNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -458,7 +458,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -514,8 +514,8 @@ class TestMockNodeFactory: def test_code_and_template_nodes_mocked_by_default(self): """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -534,7 +534,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -559,8 +559,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -579,7 +579,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -614,8 +614,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -634,7 +634,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 5b35b3310a..cb5200f8dc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -4,8 +4,8 @@ Simple test to validate the auto-mock system without external dependencies. import sys -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -98,8 +98,8 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") @@ -154,8 +154,8 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc7..37b43bd374 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,32 +4,33 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -67,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -103,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -112,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 60167c0441..59e54bd39a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,39 +4,40 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -95,7 +96,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -115,7 +116,7 @@ class DelayedHumanInputNode(HumanInputNode): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -162,6 +163,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -171,6 +173,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), delay_seconds=0.2, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b954a4faac..1a43734462 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -15,20 +15,20 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.llm.node import LLMNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -98,7 +98,7 @@ def test_parallel_streaming_workflow(): ) # Create variable pool with system variables - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=init_params.workflow_id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 7328ce443f..bcf123ee80 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,40 +4,41 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -96,7 +97,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, form: HumanInputFormEntity) -> None: self._form = form - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: if node_id != "human_pause": return None return self._form @@ -107,7 +108,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -201,6 +202,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 15a7de3c52..79d3d5bcfe 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,38 +3,39 @@ import time from typing import Any from unittest.mock import MagicMock -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.graph import GraphRunStartedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.graph_events.graph import GraphRunStartedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -50,7 +51,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -65,7 +66,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -112,6 +113,7 @@ def _build_human_input_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index 9c84f42db6..146b728dc2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -12,9 +12,9 @@ import pytest import redis from core.app.apps.base_app_queue_manager import AppQueueManager -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from dify_graph.graph_engine.manager import GraphEngineManager +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from graphon.graph_engine.manager import GraphEngineManager class TestRedisStopIntegration: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py index cd9d56f683..62ca7a630e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py @@ -4,9 +4,9 @@ from __future__ import annotations import pytest -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.nodes.base.template import Template, TextSegment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType +from graphon.graph_engine.response_coordinator.session import ResponseSession +from graphon.nodes.base.template import Template, TextSegment class DummyResponseNode: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 4f1741d4fb..a359a5fef9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -1,9 +1,10 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -33,6 +34,7 @@ def test_streaming_conversation_variables(): NodeRunSucceededEvent, # Variable Assigner node NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, # ANSWER node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8..81d68ba2aa 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,29 +12,29 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from graphon.entities.graph_init_params import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ( +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -60,20 +60,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory = MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -81,13 +89,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -206,14 +212,15 @@ class WorkflowRunner: call_depth=0, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +249,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +272,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 7f26bc11a7..12aec6edf2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py index f63e8ff4ce..2ad41037a9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py @@ -2,9 +2,9 @@ from unittest.mock import patch import pytest -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from .test_table_runner import TableTestRunner, WorkflowTestCase diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py new file mode 100644 index 0000000000..60cab77c0a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py @@ -0,0 +1,129 @@ +import time +import uuid +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import NodeRunVariableUpdatedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringVariable + +DEFAULT_NODE_ID = "node_id" + + +class CaptureVariableUpdateLayer(GraphEngineLayer): + def __init__(self) -> None: + super().__init__() + self.events: list[NodeRunVariableUpdatedEvent] = [] + self.observed_values: list[object | None] = [] + + def on_graph_start(self) -> None: + pass + + def on_event(self, event) -> None: + if not isinstance(event, NodeRunVariableUpdatedEvent): + return + + current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) + self.events.append(event) + self.observed_values.append(None if current_value is None else current_value.value) + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_graph_engine_applies_variable_updates_before_notifying_layers(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "Start"}, "id": "start"}, + { + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "over-write", + "input_variable_selector": ["node_id", "test_string_variable"], + }, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + workflow_id="1", + graph_config=graph_config, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, + call_depth=0, + ) + + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), + conversation_variables=[ + StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + ], + ), + ) + variable_pool.add( + [DEFAULT_NODE_ID, "test_string_variable"], + StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ), + ) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + engine = GraphEngine( + workflow_id="workflow-id", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + ) + capture_layer = CaptureVariableUpdateLayer() + engine.layer(capture_layer) + + events = list(engine.run()) + + update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] + assert len(update_events) == 1 + assert update_events[0].variable.value == "the second value" + + current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) + assert current_value is not None + assert current_value.value == "the second value" + + assert len(capture_layer.events) == 1 + assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py index bc00b49fba..85132674b8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py @@ -4,15 +4,16 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import MagicMock, patch -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.graph_engine.worker import Worker -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.ready_queue import InMemoryReadyQueue +from graphon.graph_engine.worker import Worker +from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) + mock_datetime = mocker.patch("graphon.graph_engine.worker.datetime") + mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) worker = Worker( ready_queue=InMemoryReadyQueue(), @@ -75,7 +76,8 @@ def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("graphon.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] @@ -135,7 +137,8 @@ def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_time worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("graphon.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 0000000000..1f4509af9a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from graphon.enums import BuiltinNodeTypes + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 0000000000..c86de7f6e6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from graphon.model_runtime.entities.model_entities import ModelType + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2..9c0ad25b58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -4,12 +4,12 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.answer.answer_node import AnswerNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +48,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 81d3f5be9c..ec4cef1955 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,9 +1,9 @@ import pytest from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 972a945ca0..ef0df55995 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -2,15 +2,15 @@ import types from collections.abc import Mapping from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from dify_graph.nodes.variable_assigner.v1.node import ( +from graphon.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from dify_graph.nodes.variable_assigner.v2.node import ( +from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 784e08edd2..ce0c9b79c6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,13 @@ from configs import dify_config -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.exc import ( +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.types import SegmentType CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index de7ed0815e..20fe2c1a74 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.variables.types import SegmentType class TestCodeNodeDataOutput: diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3..1d76067ec2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,7 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py index cd822a6f89..f1a48f49b9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.http_request import build_http_request_config +from graphon.nodes.http_request import build_http_request_config def test_build_http_request_config_uses_literal_defaults(): diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index fec6ad90eb..88895608d9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, PropertyMock, patch import httpx import pytest -from dify_graph.nodes.http_request.entities import Response +from graphon.nodes.http_request.entities import Response @pytest.fixture @@ -104,7 +104,7 @@ def test_mimetype_based_detection(mock_response, content_type, expected_main_typ mock_response.headers = {"content-type": content_type} type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: + with patch("graphon.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: # Mock the return value based on expected_main_type if expected_main_type: mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea7195417..be7cc073db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,19 +2,19 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import ( +from core.workflow.system_variables import default_system_variables +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout -from dify_graph.nodes.http_request.exc import AuthorizationConfigError -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout +from graphon.nodes.http_request.exc import AuthorizationConfigError +from graphon.nodes.http_request.executor import Executor +from graphon.runtime import VariablePool HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94..a3cadc0681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -7,12 +7,13 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -109,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) @@ -161,7 +163,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) + monkeypatch.setattr("graphon.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65..1d6a4da7c4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,5 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from dify_graph.runtime import VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients +from graphon.runtime import VariablePool def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c0..5f28a07606 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,35 +8,38 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.node_events import PauseRequestedEvent -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.human_input.entities import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import HumanInputFormRepository +from core.workflow.human_input_compat import ( + DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, + EmailRecipientType, ExternalRecipient, - FormInput, - FormInputDefault, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, _WebAppDeliveryConfig, ) -from dify_graph.nodes.human_input.enums import ( +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( + FormInput, + FormInputDefault, + HumanInputNodeData, + UserAction, +) +from graphon.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit, ) -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -54,9 +57,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -193,7 +196,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -212,7 +215,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -261,10 +264,10 @@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -273,37 +276,46 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + recipients = EmailRecipients(whole_workspace=True, items=[]) + assert recipients.include_bound_group is True + assert recipients.items == [] + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -353,17 +365,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -378,7 +392,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -416,28 +430,96 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") + + def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): + variable_pool = VariablePool( + system_variables=build_system_variables( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-4", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "end-user-1", + "user_from": "end-user", + "invoke_from": "web-app", + } + }, + call_depth=0, + ) + + config = { + "id": "human", + "data": { + "type": "human-input", + "title": "Human Input", + "form_content": "Provide your name", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "delivery_methods": [{"enabled": True, "type": "webapp", "config": {}}], + }, + } + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-4", + rendered_content="Provide your name", + submission_token="token", + recipients=[], + submitted=False, + ) + + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + runtime=runtime, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + params = mock_repo.create_form.call_args.args[0] + assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -472,7 +554,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -489,17 +571,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -511,11 +595,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -552,7 +636,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -591,12 +675,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index b0ed47158d..fc4497f010 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,18 +1,19 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now @@ -25,7 +26,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -85,11 +86,12 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -149,6 +151,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py index 93c199514e..8cc91bdb54 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.iteration.entities import ( +from graphon.nodes.iteration.entities import ( ErrorHandleMode, IterationNodeData, IterationStartNodeData, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index fdf5f4d1f8..58b82aa893 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,7 +1,7 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.exc import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.exc import ( InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -9,7 +9,7 @@ from dify_graph.nodes.iteration.exc import ( IteratorVariableNotFoundError, StartNodeIdNotFoundError, ) -from dify_graph.nodes.iteration.iteration_node import IterationNode +from graphon.nodes.iteration.iteration_node import IterationNode class TestIterationNodeExceptions: diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py new file mode 100644 index 0000000000..4c3ad85fcd --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py @@ -0,0 +1,201 @@ +from threading import Event +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph_events import GraphRunAbortedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.exc import ChildGraphAbortedError +from graphon.nodes.iteration.iteration_node import IterationNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage + + +class _AbortOnRequestGraphEngine: + def __init__(self, *, index: int, total_tokens: int) -> None: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + self.started = Event() + self.abort_requested = Event() + self.finished = Event() + self.abort_reason: str | None = None + self.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ) + + def request_abort(self, reason: str | None = None) -> None: + self.abort_reason = reason + self.abort_requested.set() + + def run(self): + self.started.set() + assert self.abort_requested.wait(1), "parallel sibling never received an abort request" + self.finished.set() + yield GraphRunAbortedEvent(reason=self.abort_reason) + + +def _build_immediate_abort_graph_engine( + *, + index: int, + total_tokens: int, + wait_before_abort: Event | None = None, +) -> SimpleNamespace: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + started = Event() + finished = Event() + + def run(): + started.set() + if wait_before_abort is not None: + assert wait_before_abort.wait(1), "parallel sibling never started" + finished.set() + yield GraphRunAbortedEvent(reason="quota exceeded") + + return SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ), + run=run, + request_abort=lambda reason=None: None, + started=started, + finished=finished, + ) + + +def _build_iteration_node( + *, + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, + is_parallel: bool = False, +) -> IterationNode: + node = IterationNode.__new__(IterationNode) + node._node_id = "iteration-node" + node._node_data = IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id="child-start", + is_parallel=is_parallel, + parallel_nums=2, + error_handle_mode=error_handle_mode, + ) + + variable_pool = build_test_variable_pool() + variable_pool.add(["start", "items"], ["first", "second"]) + node.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=LLMUsage.empty_usage(), + ) + return node + + +def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: + node = _build_iteration_node() + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], 0) + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): + list( + node._run_single_iter( + variable_pool=variable_pool, + outputs=[], + graph_engine=graph_engine, + ) + ) + + +def test_iteration_run_fails_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + node._create_graph_engine.assert_called_once() + node._run_single_iter.assert_called_once() + + +def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=_usage_with_tokens(7), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 + + +@pytest.mark.parametrize( + "error_handle_mode", + [ + ErrorHandleMode.CONTINUE_ON_ERROR, + ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, + ], +) +def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( + error_handle_mode: ErrorHandleMode, +) -> None: + node = _build_iteration_node( + error_handle_mode=error_handle_mode, + is_parallel=True, + ) + blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) + aborting_engine = _build_immediate_abort_graph_engine( + index=0, + total_tokens=3, + wait_before_abort=blocking_engine.started, + ) + node._create_graph_engine = MagicMock( + side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] + ) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + assert events[-1].node_run_result.llm_usage.total_tokens == 8 + assert node.graph_runtime_state.llm_usage.total_tokens == 8 + assert blocking_engine.started.is_set() + assert blocking_engine.abort_requested.is_set() + assert blocking_engine.finished.is_set() + assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f..82cc734274 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,18 +1,18 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError -from dify_graph.nodes.iteration.iteration_node import IterationNode -from dify_graph.runtime import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.nodes.iteration.exc import IterationGraphNotFoundError +from graphon.nodes.iteration.iteration_node import IterationNode +from graphon.runtime import ( ChildEngineBuilderNotConfiguredError, ChildGraphNotFoundError, GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -22,17 +22,16 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py index 8660449032..41d7c3193d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -1,14 +1,13 @@ import time -from contextlib import nullcontext from datetime import UTC, datetime import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.iteration_node import IterationNode +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.iteration_node import IterationNode def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: @@ -21,11 +20,17 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: parallel_nums=2, error_handle_mode=ErrorHandleMode.TERMINATED, ) - node._capture_execution_context = lambda: nullcontext() - node._sync_conversation_variables_from_snapshot = lambda snapshot: None node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): + def fake_execute_tracked_iteration_parallel( + *, + index: int, + item: object, + started_child_engines: dict[int, object], + started_child_engines_lock: object, + ): + _ = started_child_engines + _ = started_child_engines_lock return ( 0.1 + (index * 0.1), [ @@ -37,11 +42,10 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: ), ], f"output-{item}", - {}, LLMUsage.empty_usage(), ) - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel outputs: list[object] = [] iter_run_map: dict[str, float] = {} diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..a6fca1bfb4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,6 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -14,10 +15,10 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b2..45e8ae7d20 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -16,11 +16,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringSegment +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -157,7 +157,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -441,7 +441,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" variables = {"query": query} diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1..eca34f05be 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.list_operator.node import ListOperatorNode -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.node import ListOperatorNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index b0f0fd428b..4f9ba0194a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -6,17 +6,14 @@ from unittest.mock import MagicMock import httpx import pytest -from core.helper import ssrf_proxy -from core.tools import signature -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import FileTransferMethod, FileType, models -from dify_graph.nodes.llm.file_saver import ( +from graphon.file import FileTransferMethod, FileType +from graphon.nodes.llm.file_saver import ( FileSaverImpl, _extract_content_type_and_extension, _get_extension, _validate_extension_override, ) -from models import ToolFile +from graphon.nodes.protocols import ToolFileManagerProtocol _PNG_DATA = b"\x89PNG\r\n\x1a\n" @@ -27,58 +24,45 @@ def _gen_id(): class TestFileSaverImpl: def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - user_id = _gen_id() - tenant_id = _gen_id() file_type = FileType.IMAGE mime_type = "image/png" - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) + mock_tool_file = MagicMock() mock_tool_file.id = _gen_id() - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - + mock_tool_file.name = f"{_gen_id()}.png" + mock_tool_file.file_key = "test-file-key" + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. - monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url + file_reference = MagicMock() + file_reference_factory = MagicMock() + file_reference_factory.build_from_mapping.return_value = file_reference http_client = MagicMock() - storage_file_manager = FileSaverImpl( - user_id=user_id, - tenant_id=tenant_id, + file_saver = FileSaverImpl( + tool_file_manager=mocked_tool_file_manager, + file_reference_factory=file_reference_factory, http_client=http_client, ) - file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file.tenant_id == tenant_id - assert file.type == file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == ".png" - assert file.mime_type == mime_type - assert file.size == len(_PNG_DATA) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url + file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file is file_reference mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) + file_reference_factory.build_from_mapping.assert_called_once_with( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": mock_tool_file.name, + "extension": ".png", + "mime_type": mime_type, + "size": len(_PNG_DATA), + "tool_file_id": mock_tool_file.id, + "related_id": mock_tool_file.id, + "storage_key": mock_tool_file.file_key, + } + ) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" @@ -91,8 +75,8 @@ class TestFileSaverImpl: http_client.get.return_value = mock_response file_saver = FileSaverImpl( - user_id=_gen_id(), - tenant_id=_gen_id(), + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), http_client=http_client, ) @@ -104,8 +88,6 @@ class TestFileSaverImpl: def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" mime_type = "image/png" - user_id = _gen_id() - tenant_id = _gen_id() mock_request = httpx.Request("GET", _TEST_URL) mock_response = httpx.Response( @@ -117,21 +99,13 @@ class TestFileSaverImpl: http_client = MagicMock() http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), + file_saver = FileSaverImpl( + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), + http_client=http_client, ) - mock_tool_file.id = _gen_id() - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + expected_file = MagicMock() + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file) monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) @@ -141,7 +115,7 @@ class TestFileSaverImpl: FileType.IMAGE, extension_override=".png", ) - assert file == mock_tool_file + assert file is expected_file def test_validate_extension_override(): diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 618a498659..dfc982f49c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -3,12 +3,94 @@ from unittest import mock import pytest from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent -from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage -from dify_graph.nodes.llm.exc import NoPromptFoundError -from dify_graph.runtime import VariablePool +from graphon.file import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from graphon.nodes.llm.exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) +from graphon.runtime import VariablePool +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label={"en_US": "GPT-3.5 Turbo"}, + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_model_instance(*, model_schema: AIModelEntity | None = None) -> mock.MagicMock: + model_instance = mock.MagicMock(spec=ModelInstance) + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.get_model_schema.return_value = model_schema or _build_model_schema(features=[]) + model_instance.get_llm_num_tokens.return_value = 0 + return model_instance + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool def _fetch_prompt_messages_with_mocked_content(content): @@ -24,15 +106,15 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( - "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + "graphon.nodes.llm.llm_utils.fetch_model_schema", return_value=mock.MagicMock(features=[]), ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_list_messages", + "graphon.nodes.llm.llm_utils.handle_list_messages", return_value=[SystemPromptMessage(content=content)], ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + "graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[], ), ): @@ -53,6 +135,159 @@ def _fetch_prompt_messages_with_mocked_content(content): ) +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): with pytest.raises(NoPromptFoundError): _fetch_prompt_messages_with_mocked_content( @@ -104,3 +339,700 @@ def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_ ] ) ] + + +def test_fetch_model_schema_raises_when_model_schema_is_missing(): + model_instance = _build_model_instance() + model_instance.get_model_schema.return_value = None + + with pytest.raises(ValueError, match="Model schema not found for gpt-3.5-turbo"): + llm_utils.fetch_model_schema(model_instance=model_instance) + + +def test_fetch_files_supports_known_segments_and_rejects_invalid_types(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + variable_pool = VariablePool.empty() + variable_pool.add(["input", "file"], file) + variable_pool.add(["input", "files"], ArrayFileSegment(value=[file])) + variable_pool.add(["input", "none"], NoneSegment()) + variable_pool.add(["input", "empty"], ArrayAnySegment(value=[])) + variable_pool.add(["input", "invalid"], {"a": 1}) + + assert llm_utils.fetch_files(variable_pool, ["input", "file"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "files"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "none"]) == [] + assert llm_utils.fetch_files(variable_pool, ["input", "empty"]) == [] + + with pytest.raises(InvalidVariableTypeError, match="Invalid variable type"): + llm_utils.fetch_files(variable_pool, ["input", "invalid"]) + + +def test_fetch_files_returns_empty_for_missing_variable(): + assert llm_utils.fetch_files(VariablePool.empty(), ["input", "missing"]) == [] + + +def test_convert_history_messages_to_text_skips_system_messages_and_formats_images(): + history_text = llm_utils.convert_history_messages_to_text( + history_messages=[ + SystemPromptMessage(content="skip"), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="Answer"), + ], + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert history_text == "Human: Question\n[image]\nAssistant: Answer" + + +def test_fetch_memory_text_uses_prompt_memory_interface(): + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=321, + message_limit=2, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert memory_text == "Human: Question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_handle_list_messages_renders_jinja2_messages(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=renderer, + ) + + assert prompt_messages == [SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")])] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_handle_list_messages_splits_text_and_file_content(): + variable_pool = VariablePool.empty() + image_file = _build_image_file( + file_id="image-file", + related_id="image-related", + remote_url="https://example.com/file.png", + ) + variable_pool.add(["input", "image"], image_file) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ) as mock_to_prompt: + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="Analyze {{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Analyze ")]), + UserPromptMessage( + content=[ + ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ] + ), + ] + mock_to_prompt.assert_called_once() + + +def test_handle_list_messages_supports_array_file_segments(): + variable_pool = VariablePool.empty() + first_file = _build_image_file(file_id="first", related_id="first-related", remote_url="https://example.com/1.png") + second_file = _build_image_file( + file_id="second", + related_id="second-related", + remote_url="https://example.com/2.png", + ) + variable_pool.add(["input", "images"], ArrayFileSegment(value=[first_file, second_file])) + + first_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/1.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + second_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/2.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=[first_prompt, second_prompt], + ): + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [UserPromptMessage(content=[first_prompt, second_prompt])] + + +def test_render_jinja2_message_handles_empty_template_success_and_missing_renderer(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + assert ( + llm_utils.render_jinja2_message( + template="", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + == "" + ) + + with pytest.raises(ValueError, match="template_renderer is required"): + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + assert ( + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=renderer, + ) + == "Hello Dify" + ) + + +def test_handle_completion_template_supports_basic_and_jinja2_templates(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + basic_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize {{#context#}}", + edition_type="basic", + ), + context="the docs", + jinja2_variables=[], + variable_pool=variable_pool, + ) + jinja_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="Hello {{ name }}", + edition_type="jinja2", + ), + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + template_renderer=renderer, + ) + + assert basic_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Summarize the docs")]), + ] + assert jinja_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + + +def test_combine_message_content_with_role_handles_all_supported_roles(): + contents = [TextPromptMessageContent(data="hello")] + + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.USER) == ( + UserPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.ASSISTANT) == ( + AssistantPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.SYSTEM) == ( + SystemPromptMessage(content=contents) + ) + + with pytest.raises(NotImplementedError, match="Role custom is not supported"): + llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] + + +def test_calculate_rest_token_uses_context_size_and_template_alias(): + model_instance = _build_model_instance( + model_schema=_build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="output_limit", + use_template="max_tokens", + label={"en_US": "Output Limit"}, + type=ParameterType.INT, + ) + ], + ) + ) + model_instance.parameters = {"max_tokens": 512} + model_instance.get_llm_num_tokens.return_value = 256 + + assert ( + llm_utils.calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 3328 + ) + + +def test_handle_memory_chat_mode_returns_empty_without_memory_and_uses_window_when_present(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + assert ( + llm_utils.handle_memory_chat_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == [] + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=123) as mock_rest: + messages = llm_utils.handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + assert messages == [UserPromptMessage(content="Question")] + mock_rest.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=123, message_limit=2) + + +def test_handle_memory_completion_mode_validates_role_prefix_and_formats_history(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="Question"), + AssistantPromptMessage(content="Answer"), + ] + + assert ( + llm_utils.handle_memory_completion_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == "" + ) + + with ( + mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456), + pytest.raises(MemoryRolePrefixRequiredError, match="Memory role prefix is required"), + ): + llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456): + history_text = llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ), + model_instance=model_instance, + ) + + assert history_text == "Human: Question\nAssistant: Answer" + memory.get_history_prompt_messages.assert_called_with(max_token_limit=456, message_limit=None) + + +def test_append_file_prompts_merges_with_existing_user_content_or_appends_new_message(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + file_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + prompt_messages = [UserPromptMessage(content=[TextPromptMessageContent(data="Question")])] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[file_prompt, TextPromptMessageContent(data="Question")]), + ] + + prompt_messages = [SystemPromptMessage(content="System prompt")] + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages[-1] == UserPromptMessage(content=[file_prompt]) + + +def test_fetch_prompt_messages_chat_mode_includes_query_memory_and_supported_files(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.VISION])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history")] + sys_file = _build_image_file(file_id="sys", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + file_prompts = [ + ImagePromptMessageContent( + format="png", + url="https://example.com/sys.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + format="png", + url="https://example.com/context.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=file_prompts, + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history") + assert prompt_messages[2] == UserPromptMessage( + content=[ + file_prompts[1], + file_prompts[0], + TextPromptMessageContent(data="current question"), + ] + ) + + +def test_fetch_prompt_messages_completion_mode_updates_list_content_with_histories_and_query(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="another question"), + AssistantPromptMessage(content="another answer"), + ] + + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header", + edition_type="basic", + ), + stop=None, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [ + UserPromptMessage(content="latest question\nHuman: another question\nAssistant: another answer\nPrompt header") + ] + + +def test_fetch_prompt_messages_filters_content_unsupported_by_model_features(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.DOCUMENT])) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_list_messages", + return_value=[ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + ], + ), + mock.patch("graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[]), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=("END",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("END",) + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_completion_mode_supports_string_content_and_invalid_template_type(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix #histories# and #sys.query#")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=("HALT",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [UserPromptMessage(content="Prefix history text and latest question")] + + with pytest.raises(TemplateTypeNotSupportError): + llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=object(), # type: ignore[arg-type] + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + invalid_prompt = mock.MagicMock() + invalid_prompt.content = object() + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[invalid_prompt], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + pytest.raises(ValueError, match="Invalid prompt content type"), + ): + llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix only")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content="history text\nPrefix only")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af1..a2fbc50392 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,40 +5,80 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities import GraphInitParams -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, + SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + PromptConfig, VisionConfig, VisionConfigOptions, ) -from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from graphon.nodes.llm.exc import ( + InvalidContextStructureError, + LLMNodeError, + NoPromptFoundError, + VariableNotFoundError, +) +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import ( + LLMNode, + _calculate_rest_token, + _handle_completion_template, + _handle_memory_chat_mode, + _handle_memory_completion_mode, + _render_jinja2_message, +) +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -55,6 +95,62 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -91,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -107,7 +203,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -120,9 +216,9 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +228,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") - for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +280,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) - mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +312,36 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +358,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,18 +391,18 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -284,7 +420,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -293,7 +428,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -343,7 +477,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -448,7 +581,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -524,7 +656,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -569,7 +700,7 @@ def test_fetch_files_with_non_existent_variable(): def test_handle_list_messages_basic(llm_node): messages = [ LLMNodeChatModelMessage( - text="Hello, {#context#}", + text="Hello, {{#context#}}", role=PromptMessageRole.USER, edition_type="basic", ) @@ -592,32 +723,414 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" +def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node): messages = [ LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", + text="Answer user's question with the following context:\n\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", ) ] + context = "## Overview\nSends a JSON request." result = llm_node.handle_list_messages( messages=messages, - context=None, + context=context, jinja2_variables=[], variable_pool=llm_node.graph_runtime_state.variable_pool, vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, ) - assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert result[0].content == [ + TextPromptMessageContent( + data="Answer user's question with the following context:\n\n## Overview\nSends a JSON request." + ) + ] + + +def test_handle_list_messages_renders_jinja2_messages(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_node.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + jinja2_template_renderer=renderer, ) + assert prompt_messages == [ + SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_transform_chat_messages_prefers_jinja2_text(llm_node): + completion_template = LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="completion prompt", + edition_type="jinja2", + ) + chat_messages = [ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="chat prompt", + role=PromptMessageRole.USER, + edition_type="jinja2", + ), + LLMNodeChatModelMessage( + text="keep original", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + ] + + transformed_completion = llm_node._transform_chat_messages(completion_template) + transformed_messages = llm_node._transform_chat_messages(chat_messages) + + assert transformed_completion.text == "completion prompt" + assert transformed_messages[0].text == "chat prompt" + assert transformed_messages[1].text == "keep original" + + +def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node): + llm_node.graph_runtime_state.variable_pool.add( + ["input", "items"], + ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3], + ) + llm_node.graph_runtime_state.variable_pool.add( + ["input", "context_doc"], + {"metadata": {"_source": "knowledge"}, "content": "context body"}, + ) + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[ + VariableSelector(variable="items", value_selector=["input", "items"]), + VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]), + VariableSelector(variable="payload", value_selector=["input", "payload"]), + ] + ) + } + ) + + assert llm_node._fetch_jinja_inputs(node_data) == { + "items": "alpha\nbeta\n3", + "context_doc": "context body", + "payload": '{"a": 1}', + } + + +def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node): + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])] + ) + } + ) + + with pytest.raises(VariableNotFoundError, match="Variable missing not found"): + llm_node._fetch_jinja_inputs(node_data) + + +def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_template": [ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}} with {{#input.payload#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + "memory": MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#input.name#}}", + ), + } + ) + + assert llm_node._fetch_inputs(node_data) == { + "#input.name#": "Dify", + "#input.payload#": {"active": True}, + } + + +def test_fetch_context_emits_string_context_event(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context") + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert events == [ + RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]), + ] + + +def test_fetch_context_collects_retriever_resources_and_attachments(llm_node): + attachment = _build_image_file( + file_id="attachment", + related_id="attachment-related", + remote_url="https://example.com/attachment.png", + ) + llm_node._retriever_attachment_loader = mock.MagicMock() + llm_node._retriever_attachment_loader.load.return_value = [attachment] + + llm_node.graph_runtime_state.variable_pool.add( + ["context", "value"], + [ + { + "content": "chunk body", + "summary": "chunk summary", + "files": [{"id": "file-1"}], + "metadata": { + "_source": "knowledge", + "dataset_id": "dataset-1", + "segment_id": "segment-1", + "segment_word_count": 12, + }, + }, + "tail text", + ], + ) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert len(events) == 1 + event = events[0] + assert event.context == "chunk summary\nchunk body\ntail text" + assert event.context_files == [attachment] + assert event.retriever_resources == [ + { + "position": None, + "dataset_id": "dataset-1", + "dataset_name": None, + "document_id": None, + "document_name": None, + "data_source_type": None, + "segment_id": "segment-1", + "retriever_from": None, + "score": None, + "hit_count": None, + "word_count": 12, + "segment_position": None, + "index_node_hash": None, + "content": "chunk body", + "page": None, + "doc_metadata": None, + "files": [{"id": "file-1"}], + "summary": "chunk summary", + } + ] + llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1") + + +def test_fetch_context_rejects_invalid_context_structure(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}]) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + with pytest.raises(InvalidContextStructureError, match="Invalid context structure"): + list(llm_node._fetch_context(node_data)) + + +def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")] + + sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context-file", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + + prompt_content_side_effect = [ + ImagePromptMessageContent( + url="https://example.com/sys.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + url="https://example.com/context.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch("graphon.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt: + mock_to_prompt.side_effect = prompt_content_side_effect + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + ), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history answer") + assert isinstance(prompt_messages[2], UserPromptMessage) + assert isinstance(prompt_messages[2].content, list) + assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent) + assert prompt_messages[2].content[0].url == "https://example.com/context.png" + assert prompt_messages[2].content[1].url == "https://example.com/sys.png" + assert prompt_messages[2].content[2].data == "current question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None) + + +def test_fetch_prompt_messages_completion_mode_injects_histories_and_query(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + +def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + variable_pool = VariablePool.empty() + variable_pool.add( + ["input", "image"], + _build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"), + ) + + with ( + mock.patch( + "graphon.nodes.llm.node.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + url="https://example.com/file.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + pytest.raises(NoPromptFoundError, match="No prompt found"), + ): + LLMNode.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + ) + + +def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node): + prompt_messages = _handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize the following context:\n{{#context#}}", + edition_type="basic", + ), + context="## Overview\nSends a JSON request.", + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_template_renderer=None, + ) + + assert prompt_messages == [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Summarize the following context:\n## Overview\nSends a JSON request.") + ] + ) + ] + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) @@ -635,15 +1148,15 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +1172,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -672,9 +1185,9 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -690,7 +1203,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -721,7 +1233,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -776,7 +1287,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", @@ -906,3 +1416,322 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +@pytest.mark.parametrize( + ("structured_output_enabled", "structured_output"), + [ + (False, None), + (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}), + ], +) +def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output): + model_instance = _build_prepared_llm_mock() + prompt_messages = [UserPromptMessage(content="hello")] + file_saver = mock.MagicMock(spec=LLMFileSaver) + + model_instance.invoke_llm.return_value = iter([]) + model_instance.invoke_llm_with_structured_output.return_value = iter([]) + + with ( + mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle, + mock.patch("graphon.nodes.llm.node.time.perf_counter", return_value=10.0), + ): + result = list( + LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=("STOP",), + structured_output_enabled=structured_output_enabled, + structured_output=structured_output, + file_saver=file_saver, + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + reasoning_format="separated", + ) + ) + + assert result == ["handled"] + if structured_output_enabled: + model_instance.invoke_llm_with_structured_output.assert_called_once_with( + prompt_messages=prompt_messages, + json_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, + model_parameters={}, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm.assert_not_called() + else: + model_instance.invoke_llm.assert_called_once_with( + prompt_messages=prompt_messages, + model_parameters={}, + tools=None, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm_with_structured_output.assert_not_called() + + assert mock_handle.call_args.kwargs["request_start_time"] == 10.0 + + +def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output(): + usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}) + first_chunk = LLMResultChunkWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="plan")]), + ), + structured_output={"draft": True}, + ) + final_chunk = LLMResultChunk( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]), + usage=usage, + finish_reason="stop", + ), + ) + + with mock.patch("graphon.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]): + events = list( + LLMNode.handle_invoke_result( + invoke_result=iter([first_chunk, final_chunk]), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=_build_prepared_llm_mock(), + reasoning_format="separated", + request_start_time=1.0, + ) + ) + + assert events[0] == first_chunk + assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="plan", is_final=False) + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + completed = events[3] + assert isinstance(completed, ModelInvokeCompletedEvent) + assert completed.text == "answer" + assert completed.reasoning_content == "plan" + assert completed.structured_output == {"draft": True} + assert completed.finish_reason == "stop" + assert completed.usage.total_tokens == 16 + assert completed.usage.latency == 4.0 + assert completed.usage.time_to_first_token == 1.0 + assert completed.usage.time_to_generate == 3.0 + + +def test_handle_invoke_result_wraps_structured_output_parse_errors(): + model_instance = _build_prepared_llm_mock() + model_instance.is_structured_output_parse_error.return_value = True + + def broken_stream(): + raise ValueError("bad json") + yield + + with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"): + list( + LLMNode.handle_invoke_result( + invoke_result=broken_stream(), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=model_instance, + ) + ) + + +def test_handle_blocking_result_extracts_reasoning_and_structured_output(): + invoke_result = LLMResultWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + message=AssistantPromptMessage(content="reasoningfinal answer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "final answer"}, + ) + + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "final answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "final answer"} + assert event.usage.latency == 1.234 + + +def test_fetch_structured_output_schema_validates_payload(): + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == { + "type": "object" + } + + with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={}) + + with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]}) + + +def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors(): + node_data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ), + ], + prompt_config=PromptConfig( + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])] + ), + memory=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#sys.query#}}", + ), + context=ContextConfig(enabled=True, variable_selector=["context", "value"]), + vision=VisionConfig(enabled=True), + ) + + mapping = LLMNode._extract_variable_selector_to_variable_mapping( + graph_config={}, + node_id="llm-1", + node_data=node_data, + ) + + assert mapping == { + "llm-1.#input.name#": ["input", "name"], + "llm-1.#sys.query#": ["sys", "query"], + "llm-1.#context#": ["context", "value"], + "llm-1.#files#": ["sys", "files"], + "llm-1.name": ["input", "name"], + } + + +def test_render_jinja2_message_requires_renderer_and_passes_inputs(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + with pytest.raises( + TemplateRenderError, + match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts", + ): + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + assert ( + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=renderer, + ) + == "Hello Dify" + ) + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_calculate_rest_token_uses_context_size_and_max_tokens(): + model_instance = _build_prepared_llm_mock() + model_instance.parameters = {"max_tokens": 512} + model_instance.get_model_schema.return_value = _build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + ) + ], + ) + model_instance.get_llm_num_tokens.return_value = 1000 + + assert ( + _calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 2584 + ) + + +def test_handle_memory_chat_mode_uses_calculated_token_budget(): + memory = mock.MagicMock(spec=MockTokenBufferMemory) + history = [UserPromptMessage(content="question")] + memory.get_history_prompt_messages.return_value = history + + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token: + result = _handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=_build_prepared_llm_mock(), + ) + + assert result == history + mock_rest_token.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index e40d565ef5..af1cff4e81 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -2,10 +2,10 @@ from collections.abc import Mapping, Sequence from pydantic import BaseModel, Field -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage +from graphon.file import File +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.nodes.llm.entities import LLMNodeChatModelMessage class LLMNodeTestScenario(BaseModel): diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py index fd48edc58c..ccf1077838 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -1,5 +1,5 @@ -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig -from dify_graph.variables.types import SegmentType +from graphon.nodes.parameter_extractor.entities import ParameterConfig +from graphon.variables.types import SegmentType class TestParameterConfig: diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 7eca531b62..8f8ec49f14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -7,18 +7,18 @@ from typing import Any import pytest -from dify_graph.model_runtime.entities import LLMMode -from dify_graph.nodes.llm import ModelConfig, VisionConfig -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from dify_graph.nodes.parameter_extractor.exc import ( +from factories.variable_factory import build_segment_with_type +from graphon.model_runtime.entities import LLMMode +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.variables.types import SegmentType -from factories.variable_factory import build_segment_with_type +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.variables.types import SegmentType @dataclass diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py index e57ebbd83e..01878ed692 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from dify_graph.enums import ErrorStrategy -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.enums import ErrorStrategy +from graphon.nodes.template_transform.entities import TemplateTransformNodeData class TestTemplateTransformNodeData: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f9..bc44ececd8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -3,11 +3,13 @@ from unittest.mock import MagicMock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +64,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +80,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +93,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +113,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -130,6 +132,26 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" + @pytest.mark.parametrize("max_output_length", [0, -1]) + def test_node_initialization_rejects_non_positive_max_output_length( + self, + basic_node_data, + mock_graph_runtime_state, + graph_init_params, + max_output_length, + ): + mock_renderer = MagicMock() + + with pytest.raises(ValueError, match="max_output_length must be a positive integer"): + TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=max_output_length, + ) + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool @@ -153,7 +175,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +203,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +223,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +243,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -230,6 +252,28 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error + def test_run_output_length_equal_to_limit_succeeds( + self, basic_node_data, mock_graph_runtime_state, graph_init_params + ): + mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "1234567890" + + node = TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=10, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["output"] == "1234567890" + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { @@ -263,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -291,6 +335,69 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] + def test_extract_variable_selector_to_variable_mapping_accepts_validated_node_data(self): + node_data = TemplateTransformNodeData( + title="Test", + variables=[VariableSelector(variable="var1", value_selector=["sys", "input1"])], + template="{{ var1 }}", + ) + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + + def test_extract_variable_selector_to_variable_mapping_returns_empty_mapping_without_variables(self): + node_data = { + "title": "Test", + "template": "{{ missing }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {} + + def test_extract_variable_selector_to_variable_mapping_accepts_sequence_value_selectors(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ("sys", "input1")}, + {"variable": "empty_selector", "value_selector": ()}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == { + "node_123.var1": ["sys", "input1"], + "node_123.empty_selector": [], + } + + def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ["sys", "input1"]}, + {"variable": "missing_selector"}, + ["not", "a", "mapping"], + {"variable": 1, "value_selector": ["sys", "input2"]}, + {"variable": "invalid_selector", "value_selector": ["sys", 2]}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -307,7 +414,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +453,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +482,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +512,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py new file mode 100644 index 0000000000..636237e56e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.template_transform_node import ( + DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, + TemplateTransformNode, +) +from graphon.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params + +from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 + + +@pytest.fixture +def graph_init_params(): + return build_test_graph_init_params( + workflow_id="test_workflow", + graph_config={}, + tenant_id="test_tenant", + app_id="test_app", + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + mock_state = MagicMock(spec=GraphRuntimeState) + mock_state.variable_pool = MagicMock() + return mock_state + + +def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): + node = TemplateTransformNode( + id="test_node", + config={ + "id": "test_node", + "data": { + "title": "Template Transform", + "variables": [], + "template": "hello", + }, + }, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=MagicMock(), + ) + + assert node._max_output_length == DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH + + +def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entries(): + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={"ignored": True}, + node_id="node_123", + node_data={ + "variables": [ + VariableSelector(variable="validated", value_selector=["sys", "input1"]), + {"variable": "raw", "value_selector": ("sys", "input2")}, + {"variable": "invalid_selector", "value_selector": ["sys", 3]}, + ["not", "a", "mapping"], + ] + }, + ) + + assert mapping == { + "node_123.validated": ["sys", "input1"], + "node_123.raw": ["sys", "input2"], + } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b..0522dd9d14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -3,13 +3,14 @@ from collections.abc import Mapping import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) @@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum(): graph_runtime_state=runtime_state, ) - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == UserFrom.ACCOUNT assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER assert node.get_run_context_value("missing") is None @@ -127,3 +128,29 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" assert node_data.get("missing", "fallback") == "fallback" + + +def test_node_hydration_preserves_compatibility_extra_fields(): + graph_config: dict[str, object] = {} + init_params, runtime_state = _build_context(graph_config) + node_config = NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + "compat_flag": True, + }, + } + ) + + node = _SampleNode( + id="node-1", + config=node_config, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + assert node.node_data.foo == "bar" + assert node.node_data.get("compat_flag") is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 40754974c1..87ec2d5bce 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -6,21 +6,21 @@ import pytest from docx.oxml.text.paragraph import CT_P from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from dify_graph.nodes.document_extractor.node import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment -from dify_graph.variables.variables import StringVariable +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment +from graphon.variables.variables import StringVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -183,14 +183,14 @@ def test_run_extract_text( mock_response.raise_for_status = Mock() document_extractor_node._http_client.get = Mock(return_value=mock_response) - monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) + monkeypatch.setattr("graphon.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c746a945fe..782750e02e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,19 +4,18 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.graph import Graph -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from dify_graph.variables import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.graph import Graph +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from graphon.variables import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +34,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -142,7 +141,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) @@ -253,7 +252,6 @@ def test_array_file_contains_file_name(): node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -316,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -371,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -440,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 6ca72b64b2..b217e4e8e7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,11 +2,10 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.list_operator.entities import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -15,9 +14,9 @@ from dify_graph.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from dify_graph.nodes.list_operator.exc import InvalidKeyError -from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from dify_graph.variables import ArrayFileSegment +from graphon.nodes.list_operator.exc import InvalidKeyError +from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from graphon.variables import ArrayFileSegment @pytest.fixture @@ -72,7 +71,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image1.jpg", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", @@ -80,7 +78,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="document1.pdf", type=FileType.DOCUMENT, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", @@ -88,7 +85,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image2.png", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", @@ -96,7 +92,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="audio1.mp3", type=FileType.AUDIO, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -120,14 +115,12 @@ def test_filter_files_by_type(list_operator_node): { "filename": "document1.pdf", "type": FileType.DOCUMENT, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related2", }, { "filename": "image2.png", "type": FileType.IMAGE, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related3", }, @@ -136,7 +129,6 @@ def test_filter_files_by_type(list_operator_node): for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type - assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id @@ -144,7 +136,6 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", @@ -165,7 +156,6 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py index 6372583839..d613ba154a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -1,6 +1,22 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.nodes.loop.entities import LoopNodeData -from dify_graph.nodes.loop.loop_node import LoopNode +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph_events import GraphRunAbortedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent +from graphon.nodes.loop.entities import LoopNodeData +from graphon.nodes.loop.loop_node import LoopNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: @@ -50,3 +66,85 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf ) assert seen_configs == [child_node_config] + + +def test_run_single_loop_raises_on_child_abort_event() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(RuntimeError, match="quota exceeded"): + list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) + + +def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=2, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), + ) + create_graph_engine = MagicMock(return_value=aborting_engine) + node._create_graph_engine = create_graph_engine + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], LoopStartedEvent) + assert isinstance(events[1], LoopFailedEvent) + assert events[1].error == "quota exceeded" + assert isinstance(events[2], StreamCompletedEvent) + assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[2].node_run_result.error == "quota exceeded" + create_graph_engine.assert_called_once() + + +def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), + ) + node._create_graph_engine = MagicMock(return_value=aborting_engine) + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index c5a02e87e4..efbf786a55 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -1,13 +1,14 @@ from types import SimpleNamespace from unittest.mock import MagicMock -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.nodes.question_classifier import ( +from graphon.model_runtime.entities import ImagePromptMessageContent +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.protocols import HttpClientProtocol +from graphon.nodes.question_classifier import ( QuestionClassifierNode, QuestionClassifierNodeData, ) +from graphon.template_rendering import Jinja2TemplateRenderer from tests.workflow_test_utils import build_test_graph_init_params @@ -86,7 +87,7 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon "instruction": "This is a test instruction", } ) - template_renderer = MagicMock(spec=TemplateRenderer) + template_renderer = MagicMock(spec=Jinja2TemplateRenderer) node = QuestionClassifierNode( id="node-id", config={"id": "node-id", "data": node_data.model_dump(mode="json")}, @@ -107,11 +108,11 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon ) fetch_prompt_messages = MagicMock(return_value=([], None)) monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", + "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", fetch_prompt_messages, ) monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", + "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index b8f0e25e91..543f9878de 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,19 +4,22 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.variables import Variable +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def make_start_node(user_inputs, variables): - variable_pool = VariablePool( - system_variables=SystemVariable(), - user_inputs=user_inputs, - conversation_variables=[], + variable_pool = build_test_variable_pool( + variables=build_system_variables(), + node_id="start", + inputs=user_inputs, ) config = { @@ -232,3 +235,64 @@ def test_json_object_optional_variable_not_provided(): # Current implementation raises a validation error even when the variable is optional with pytest.raises(ValueError, match="profile is required in input form"): node._run() + + +def test_start_node_outputs_full_variable_pool_snapshot(): + variable_pool = build_test_variable_pool( + variables=[ + *build_system_variables(query="hello", workflow_run_id="run-123"), + _build_prefixed_variable(ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY", "secret"), + _build_prefixed_variable(CONVERSATION_VARIABLE_NODE_ID, "session_id", "conversation-1"), + ], + node_id="start", + inputs={"profile": {"age": 20, "name": "Tom"}}, + ) + + config = { + "id": "start", + "data": StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ).model_dump(), + } + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node = StartNode( + id="start", + config=config, + graph_init_params=build_test_graph_init_params( + workflow_id="wf", + graph_config={}, + tenant_id="tenant", + app_id="app", + user_id="u", + user_from="account", + invoke_from="debugger", + call_depth=0, + ), + graph_runtime_state=graph_runtime_state, + ) + + result = node._run() + + assert result.inputs == {"profile": {"age": 20, "name": "Tom"}} + assert result.outputs["profile"] == {"age": 20, "name": "Tom"} + assert result.outputs["sys.query"] == "hello" + assert result.outputs["sys.workflow_run_id"] == "run-123" + assert result.outputs["env.API_KEY"] == "secret" + assert result.outputs["conversation.session_id"] == "conversation-1" + + +def _build_prefixed_variable(node_id: str, name: str, value: object) -> Variable: + return segment_to_variable( + segment=build_segment(value), + selector=(node_id, name), + name=name, + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3cbd96dfef..c806181340 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -3,23 +3,55 @@ from __future__ import annotations import sys import types from collections.abc import Generator +from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment +from core.workflow.system_variables import build_system_variables +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.tool.tool_node import ToolNode + + +class _StubToolRuntime: + def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + raise NotImplementedError + + def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: + return [] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: dict[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + yield from () + + def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage: + return LLMUsage.empty_usage() + + def build_file_reference(self, *, mapping: dict[str, Any]) -> Any: + return mapping + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + return default_icon, None @pytest.fixture @@ -31,8 +63,8 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from dify_graph.nodes.protocols import ToolFileManagerProtocol - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.protocols import ToolFileManagerProtocol + from graphon.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -66,13 +98,14 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id")) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + runtime = _StubToolRuntime() node = ToolNode( id="node-instance", @@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode: graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=runtime, ) return node @@ -93,29 +127,19 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: - def _identity_transform(messages, *_args, **_kwargs): - return messages - - tool_runtime = MagicMock() - with patch.object( - ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True - ): - generator = tool_node._transform_message( - messages=iter([message]), - tool_info={"provider_type": "builtin", "provider_id": "provider"}, - parameters_for_log={}, - user_id="user-id", - tenant_id="tenant-id", - node_id=tool_node._node_id, - tool_runtime=tool_runtime, - ) - return _collect_events(generator) +def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: + generator = tool_node._transform_message( + messages=iter([message]), + tool_info={"provider_type": "builtin", "provider_id": "provider"}, + parameters_for_log={}, + node_id=tool_node._node_id, + tool_runtime=ToolRuntimeHandle(raw=object()), + ) + return _collect_events(generator) def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - tenant_id="tenant-id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", @@ -125,9 +149,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): size=123, storage_key="file-key", ) - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), meta={"file": file_obj}, ) @@ -150,9 +174,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): def test_plain_link_messages_remain_links(tool_node: ToolNode): - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"), meta=None, ) @@ -167,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [] + + +def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): + file_obj = File( + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="file-id", + filename="demo.pdf", + extension=".pdf", + mime_type="application/pdf", + size=123, + storage_key="file-key", + ) + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + None, + SimpleNamespace(mime_type="application/pdf"), + ) + tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj) + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.IMAGE_LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), + meta={"tool_file_id": "file-id"}, + ) + + events, _ = _run_transform(tool_node, message) + + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert len(completed_events) == 1 + files_segment = completed_events[0].node_run_result.outputs["files"] + assert isinstance(files_segment, ArrayFileSegment) + assert files_segment.value == [file_obj] diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py new file mode 100644 index 0000000000..438af211f3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool + + +@pytest.fixture +def runtime(monkeypatch) -> DifyToolNodeRuntime: + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute + ops_stub.TraceTask = object # pragma: no cover - stub attribute + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + init_params = build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + return DifyToolNodeRuntime(init_params.run_context) + + +def _build_tool_node_data() -> ToolNodeData: + return ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + +def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: + core_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + meta=None, + ) + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables(conversation_id="conversation-id") + ) + workflow_tool = MagicMock() + + with ( + patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool), + patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, + patch.object( + ToolFileMessageTransformer, + "transform_tool_invoke_messages", + side_effect=lambda *, messages, **_: messages, + ) as transform_tool_messages, + ): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + ) + messages = list( + runtime.invoke( + tool_runtime=tool_runtime, + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + ) + + assert not hasattr(tool_runtime, "conversation_id") + assert len(messages) == 1 + graph_message = messages[0] + assert graph_message.type == ToolRuntimeMessage.MessageType.LINK + assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage) + assert graph_message.message.text == "https://dify.ai" + + callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] + assert isinstance(callback, DifyWorkflowCallbackHandler) + assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id" + + transform_kwargs = transform_tool_messages.call_args.kwargs + assert transform_kwargs["conversation_id"] == "conversation-id" + + +def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: + invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}') + + with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error): + with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"): + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + + +def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None: + usage_payload = LLMUsage.empty_usage().model_dump() + usage_payload["total_tokens"] = 42 + + usage = runtime.get_usage( + tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)), + ) + + assert usage.total_tokens == 42 + + +def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: + node_data = _build_tool_node_data() + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: + tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + + assert not hasattr(tool_runtime, "conversation_id") + workflow_tool = runtime_mock.call_args.args[3] + assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN + + +def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: + tool_runtime = ToolRuntimeHandle( + raw=SimpleNamespace( + get_merged_runtime_parameters=MagicMock( + return_value=[ + SimpleNamespace(name="city", required=True), + SimpleNamespace(name="country", required=False), + ] + ) + ) + ) + + parameters = runtime.get_runtime_parameters(tool_runtime=tool_runtime) + + assert [(parameter.name, parameter.required) for parameter in parameters] == [ + ("city", True), + ("country", False), + ] + + +def test_get_usage_returns_empty_usage_when_tool_has_no_usage(runtime: DifyToolNodeRuntime) -> None: + usage = runtime.get_usage(tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=None))) + + assert usage == LLMUsage.empty_usage() + + +@pytest.mark.parametrize( + ("payload", "expected_type"), + [ + (ToolInvokeMessage.JsonMessage(json_object={"ok": True}, suppress_output=True), ToolRuntimeMessage.JsonMessage), + (ToolInvokeMessage.BlobMessage(blob=b"bytes"), ToolRuntimeMessage.BlobMessage), + ( + ToolInvokeMessage.BlobChunkMessage( + id="blob-id", + sequence=1, + total_length=5, + blob=b"hello", + end=True, + ), + ToolRuntimeMessage.BlobChunkMessage, + ), + (ToolInvokeMessage.FileMessage(file_marker="marker"), ToolRuntimeMessage.FileMessage), + ( + ToolInvokeMessage.VariableMessage(variable_name="city", variable_value="Tokyo", stream=True), + ToolRuntimeMessage.VariableMessage, + ), + ( + ToolInvokeMessage.LogMessage( + id="log-id", + label="lookup", + status=ToolInvokeMessage.LogMessage.LogStatus.SUCCESS, + data={"count": 1}, + metadata={"source": "tool"}, + ), + ToolRuntimeMessage.LogMessage, + ), + ], +) +def test_convert_message_payload_supports_runtime_message_types( + runtime: DifyToolNodeRuntime, + payload: object, + expected_type: type[object], +) -> None: + message = runtime._convert_message_payload(payload) + + assert isinstance(message, expected_type) + + +def test_convert_message_payload_rejects_unknown_types(runtime: DifyToolNodeRuntime) -> None: + with pytest.raises(TypeError, match="unsupported tool message payload"): + runtime._convert_message_payload(object()) + + +def test_resolve_provider_icons_prefers_builtin_tool_icons(runtime: DifyToolNodeRuntime) -> None: + plugin = SimpleNamespace( + plugin_id="langgenius/tools", + name="search", + declaration=SimpleNamespace(icon={"plugin": "icon"}), + ) + builtin_tool = SimpleNamespace( + name="langgenius/tools/search", + icon={"builtin": "icon"}, + icon_dark={"builtin": "dark"}, + ) + + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]), + ): + installer_cls.return_value.list_plugins.return_value = [plugin] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search") + + assert icon == {"builtin": "icon"} + assert icon_dark == {"builtin": "dark"} + + +def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None: + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]), + ): + installer_cls.return_value.list_plugins.return_value = [] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback") + + assert icon == "fallback" + assert icon_dark is None + + +@pytest.mark.parametrize( + ("exc", "message"), + [ + (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"), + (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"), + (RuntimeError("unexpected"), "unexpected"), + ], +) +def test_map_invocation_exception_normalizes_runtime_errors( + runtime: DifyToolNodeRuntime, + exc: Exception, + message: str, +) -> None: + error = runtime._map_invocation_exception(exc, provider_name="provider") + + assert isinstance(error, ToolRuntimeInvocationError) + assert str(error) == message diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 9aeab0409e..c8ddc53284 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -2,12 +2,12 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: @@ -17,9 +17,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable(user_id="user", files=[]), - user_inputs={"payload": "value"}, + variable_pool=build_test_variable_pool( + variables=build_system_variables(user_id="user", files=[]), + node_id="node-1", + inputs={"payload": "value"}, ), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index e69c05dc0b..fabc8df73e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,22 +2,38 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable, StringVariable +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.v1 import VariableAssignerNode +from graphon.nodes.variable_assigner.v1.node_data import WriteMode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool( + *, + conversation_id: str, + conversation_variables: list[StringVariable | ArrayStringVariable], +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=conversation_id), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_overwrite_string_variable(): graph_config = { "edges": [ @@ -71,10 +87,8 @@ def test_overwrite_string_variable(): conversation_id = str(uuid.uuid4()) # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -108,16 +122,14 @@ def test_overwrite_string_variable(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" + assert updated_event.variable.value == "the second value" + assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) def test_append_variable_to_array(): @@ -172,10 +184,8 @@ def test_append_variable_to_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) variable_pool.add( @@ -208,15 +218,13 @@ def test_append_variable_to_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] + assert updated_event.variable.value == ["the first value", "the second value"] def test_clear_array(): @@ -265,10 +273,8 @@ def test_clear_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -297,12 +303,10 @@ def test_clear_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index a7673c5a14..9ac8bbe9c2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,6 +1,6 @@ -from dify_graph.nodes.variable_assigner.v2.enums import Operation -from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid -from dify_graph.variables import SegmentType +from graphon.nodes.variable_assigner.v2.enums import Operation +from graphon.nodes.variable_assigner.v2.helpers import is_input_value_valid +from graphon.variables import SegmentType def test_is_input_value_valid_overwrite_array_string(): diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 6874f3fef1..53346c4a90 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,20 +2,33 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_events import NodeRunVariableUpdatedEvent +from graphon.nodes.variable_assigner.v2 import VariableAssignerNode +from graphon.nodes.variable_assigner.v2.enums import InputType, Operation +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id="conversation_id"), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_handle_item_directly(): """Test the _handle_item method directly for remove operations.""" # Create variables @@ -106,12 +119,7 @@ def test_remove_first_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -146,11 +154,8 @@ def test_remove_first_from_array(): # Run the node result = list(node.run()) - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["second", "third"] def test_remove_last_from_array(): @@ -194,12 +199,7 @@ def test_remove_last_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -231,11 +231,9 @@ def test_remove_last_from_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["first", "second"] def test_remove_first_from_empty_array(): @@ -279,12 +277,7 @@ def test_remove_first_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -316,11 +309,9 @@ def test_remove_first_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_remove_last_from_empty_array(): @@ -364,12 +355,7 @@ def test_remove_last_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -401,11 +387,9 @@ def test_remove_last_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_node_factory_creates_variable_assigner_node(): @@ -433,12 +417,7 @@ def test_node_factory_creates_variable_assigner_node(): }, call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(conversation_variables=[]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 6be5bb23e8..be18391b2c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -324,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.entities.base_node_data import BaseNodeData + from graphon.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index ddf1af5a59..617554ee17 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -6,7 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3..6fbd26131d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,7 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,11 +16,12 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,14 +134,14 @@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory and variable factory + # Mock the file reference boundary and variable factory with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -153,8 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) # Verify segment factory was called to create FileSegment @@ -184,16 +195,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +229,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +239,7 @@ def test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +265,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +275,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +308,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,13 +322,13 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -350,8 +357,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) @@ -370,9 +376,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,13 +388,13 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -430,9 +435,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +445,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3..9f954b2090 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,7 +2,7 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,13 +12,14 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import FileVariable, StringVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool +from graphon.variables import FileVariable, StringVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +63,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +85,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +125,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +137,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +160,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +171,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +195,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +208,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +225,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +234,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +251,8 @@ def test_webhook_node_run_with_file_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -262,14 +262,14 @@ def test_webhook_node_run_with_file_params(): "document": file2.to_dict(), }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -284,7 +284,6 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -303,23 +302,22 @@ def test_webhook_node_run_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, "files": {"upload": file_obj.to_dict()}, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -343,10 +341,7 @@ def test_webhook_node_run_empty_webhook_data(): body=[WebhookBodyParameter(name="message", type="string", required=False)], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, # No webhook_data - ) + variable_pool = build_webhook_variable_pool({}) # No webhook_data node = create_webhook_node(data, variable_pool) result = node._run() @@ -369,9 +364,8 @@ def test_webhook_node_run_case_insensitive_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "content-type": "application/json", # lowercase @@ -382,7 +376,7 @@ def test_webhook_node_run_case_insensitive_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -399,12 +393,11 @@ def test_webhook_node_variable_pool_user_inputs(): data = WebhookData(title="Test Webhook") # Add some additional variables to the pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", - }, + } ) variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value")) @@ -430,16 +423,15 @@ def test_webhook_node_different_methods(method): method=method, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index e8ce6f60f7..453e0a8502 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -1,6 +1,6 @@ """Tests for workflow pause related enums and constants.""" -from dify_graph.enums import ( +from graphon.enums import ( WorkflowExecutionStatus, ) diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py new file mode 100644 index 0000000000..0623800b30 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace + +from pydantic import BaseModel + +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + is_human_input_webapp_enabled, + normalize_human_input_node_data_for_graph, + normalize_node_config_for_graph, + normalize_node_data_for_graph, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert "", - "'; DROP TABLE users; --", - "../../../etc/passwd", - "\\x00\\x00", # null bytes - "A" * 10000, # very long input - ], - ) - def test_validate_api_key_auth_args_malicious_input(self, malicious_input): - """Test API key auth args validation - malicious input""" - args = self.mock_args.copy() - args["category"] = malicious_input - - # Verify parameter validator doesn't crash on malicious input - # Should validate normally rather than raising security-related exceptions - ApiKeyAuthService.validate_api_key_auth_args(args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - database error handling""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - mock_encrypter.encrypt_token.return_value = "encrypted_key" - - # Mock database error - mock_session.commit.side_effect = Exception("Database error") - - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_invalid_json(self, mock_session): - """Test get auth credentials - invalid JSON""" - # Mock database returning invalid JSON - mock_binding = Mock() - mock_binding.credentials = "invalid json content" - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - with pytest.raises(json.JSONDecodeError): - ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_factory_exception(self, mock_factory, mock_session): - """Test create provider auth - factory exception""" - # Mock factory raising exception - mock_factory.side_effect = Exception("Factory error") - - with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - encryption exception""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption exception - mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") - - with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - def test_validate_api_key_auth_args_none_input(self): - """Test API key auth args validation - None input""" - with pytest.raises(TypeError): - ApiKeyAuthService.validate_api_key_auth_args(None) - - def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): - """Test API key auth args validation - dict credentials with list auth_type""" - args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] - - # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy - # So this should not raise exception, this test should pass - ApiKeyAuthService.validate_api_key_auth_args(args) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py deleted file mode 100644 index 3832a0b8b2..0000000000 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -API Key Authentication System Integration Tests -""" - -import json -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, patch - -import httpx -import pytest - -from services.auth.api_key_auth_factory import ApiKeyAuthFactory -from services.auth.api_key_auth_service import ApiKeyAuthService -from services.auth.auth_type import AuthType - - -class TestAuthIntegration: - def setup_method(self): - self.tenant_id_1 = "tenant_123" - self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing - self.category = "search" - - # Realistic authentication configurations - self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} - self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} - self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): - """Test complete authentication flow: request → validation → encryption → storage""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_fc_test_key_123" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_http.assert_called_once() - call_args = mock_http.call_args - assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] - assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" - - mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123") - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_cross_component_integration(self, mock_http): - """Test factory → provider → HTTP call integration""" - mock_http.return_value = self._create_success_response() - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - result = factory.validate_credentials() - - assert result is True - mock_http.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - def test_multi_tenant_isolation(self, mock_session): - """Ensure complete tenant data isolation""" - tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) - tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - - mock_session.scalars.return_value.all.return_value = [tenant1_binding] - result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - - mock_session.scalars.return_value.all.return_value = [tenant2_binding] - result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) - - assert len(result1) == 1 - assert result1[0].tenant_id == self.tenant_id_1 - assert len(result2) == 1 - assert result2[0].tenant_id == self.tenant_id_2 - - @patch("services.auth.api_key_auth_service.db.session") - def test_cross_tenant_access_prevention(self, mock_session): - """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.where.return_value.first.return_value = None - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) - - assert result is None - - def test_sensitive_data_protection(self): - """Ensure API keys don't leak to logs""" - credentials_with_secrets = { - "auth_type": "bearer", - "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, - } - - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) - factory_str = str(factory) - - assert "super_secret_key_do_not_log" not in factory_str - assert "another_secret" not in factory_str - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): - """Test concurrent authentication creation safety""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_key" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - results = [] - exceptions = [] - - def create_auth(): - try: - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - results.append("success") - except Exception as e: - exceptions.append(e) - - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_auth) for _ in range(5)] - for future in futures: - future.result() - - assert len(results) == 5 - assert len(exceptions) == 0 - assert mock_session.add.call_count == 5 - assert mock_session.commit.call_count == 5 - - @pytest.mark.parametrize( - "invalid_input", - [ - None, # Null input - {}, # Empty dictionary - missing required fields - {"auth_type": "bearer"}, # Missing config section - {"auth_type": "bearer", "config": {}}, # Missing api_key - ], - ) - def test_invalid_input_boundary(self, invalid_input): - """Test boundary handling for invalid inputs""" - with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): - ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_http_error_handling(self, mock_http): - """Test proper HTTP error handling""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = '{"error": "Unauthorized"}' - mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") - mock_http.return_value = mock_response - - # PT012: Split into single statement for pytest.raises - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - with pytest.raises((httpx.HTTPError, Exception)): - factory.validate_credentials() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_network_failure_recovery(self, mock_http, mock_session): - """Test system recovery from network failures""" - mock_http.side_effect = httpx.RequestError("Network timeout") - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_session.commit.assert_not_called() - - @pytest.mark.parametrize( - ("provider", "credentials"), - [ - (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), - (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), - (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), - ], - ) - def test_all_providers_factory_creation(self, provider, credentials): - """Test factory creation for all supported providers""" - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None - - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - - def _create_success_response(self, status_code=200): - """Create successful HTTP response mock""" - mock_response = Mock() - mock_response.status_code = status_code - mock_response.json.return_value = {"status": "success"} - mock_response.raise_for_status.return_value = None - return mock_response - - def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock: - """Create realistic database binding mock""" - mock_binding = Mock() - mock_binding.id = f"binding_{provider}_{tenant_id}" - mock_binding.tenant_id = tenant_id - mock_binding.category = self.category - mock_binding.provider = provider - mock_binding.credentials = json.dumps(credentials, ensure_ascii=False) - mock_binding.disabled = False - - mock_binding.created_at = Mock() - mock_binding.created_at.timestamp.return_value = 1640995200 - mock_binding.updated_at = Mock() - mock_binding.updated_at.timestamp.return_value = 1640995200 - - return mock_binding - - def test_integration_coverage_validation(self): - """Validate integration test coverage meets quality standards""" - core_scenarios = { - "business_logic": ["end_to_end_auth_flow", "cross_component_integration"], - "security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"], - "reliability": ["concurrent_creation_safety", "network_failure_recovery"], - "compatibility": ["all_providers_factory_creation"], - "boundaries": ["invalid_input_boundary", "http_error_handling"], - } - - total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values()) - assert total_scenarios >= 10 - - security_tests = core_scenarios["security"] - assert "multi_tenant_isolation" in security_tests - assert "sensitive_data_protection" in security_tests - assert True diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py new file mode 100644 index 0000000000..c95b60fad0 --- /dev/null +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -0,0 +1,455 @@ +"""Shared helpers for dataset_service unit tests. + +These factories and lightweight builders are reused across the dataset, +document, and segment service test modules that exercise +``api/services/dataset_service.py``. +""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, create_autospec, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from enums.cloud_plan import CloudPlan +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from models import Account, TenantAccountRole +from models.dataset import ( + ChildChunk, + Dataset, + DatasetPermissionEnum, + DatasetProcessRule, + Document, + DocumentSegment, +) +from models.model import UploadFile +from services.dataset_service import ( + DatasetCollectionBindingService, + DatasetPermissionService, + DatasetService, + DocumentService, + SegmentService, +) +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + DataSource, + FileInfo, + InfoList, + KnowledgeConfig, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalModel, + Rule, + Segmentation, + SegmentUpdateArgs, + WebsiteInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + IconInfo as PipelineIconInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RerankingModelConfig as RagPipelineRerankingModelConfig, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RetrievalSetting as RagPipelineRetrievalSetting, +) +from services.errors.account import NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.dataset import DatasetNameDuplicateError +from services.errors.document import DocumentIndexingError +from services.errors.file import FileNotExistsError + +__all__ = [ + "Account", + "BuiltInField", + "ChildChunk", + "ChildChunkDeleteIndexError", + "ChildChunkIndexingError", + "ChildChunkUpdateArgs", + "CloudPlan", + "DataSource", + "Dataset", + "DatasetCollectionBindingService", + "DatasetNameDuplicateError", + "DatasetPermissionEnum", + "DatasetPermissionService", + "DatasetProcessRule", + "DatasetService", + "DatasetServiceUnitDataFactory", + "Document", + "DocumentIndexingError", + "DocumentSegment", + "DocumentService", + "FileInfo", + "FileNotExistsError", + "Forbidden", + "IndexStructureType", + "InfoList", + "KnowledgeConfig", + "KnowledgeConfiguration", + "LLMBadRequestError", + "MagicMock", + "Mock", + "ModelFeature", + "ModelType", + "NoPermissionError", + "NotFound", + "NotionIcon", + "NotionInfo", + "NotionPage", + "PipelineIconInfo", + "PreProcessingRule", + "ProcessRule", + "ProviderTokenNotInitError", + "RagPipelineDatasetCreateEntity", + "RagPipelineRerankingModelConfig", + "RagPipelineRetrievalSetting", + "RerankingModel", + "RetrievalMethod", + "RetrievalModel", + "Rule", + "SegmentService", + "SegmentUpdateArgs", + "Segmentation", + "SimpleNamespace", + "TenantAccountRole", + "WebsiteInfo", + "_make_child_chunk", + "_make_dataset", + "_make_document", + "_make_features", + "_make_knowledge_configuration", + "_make_lock_context", + "_make_retrieval_model", + "_make_segment", + "_make_session_context", + "_make_upload_knowledge_config", + "create_autospec", + "json", + "patch", + "pytest", +] + + +def _make_session_context(session: MagicMock) -> MagicMock: + """Wrap a mocked session in a context manager.""" + context_manager = MagicMock() + context_manager.__enter__.return_value = session + context_manager.__exit__.return_value = False + return context_manager + + +class DatasetServiceUnitDataFactory: + """Factory for lightweight doubles used across dataset service tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + permission: str = DatasetPermissionEnum.ALL_TEAM, + created_by: str = "user-123", + indexing_technique: str = "economy", + embedding_model_provider: str = "provider", + embedding_model: str = "model", + built_in_field_enabled: bool = False, + doc_form: str | None = "text_model", + enable_api: bool = False, + summary_index_setting: dict | None = None, + **kwargs, + ) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = permission + dataset.created_by = created_by + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + dataset.built_in_field_enabled = built_in_field_enabled + dataset.doc_form = doc_form + dataset.enable_api = enable_api + dataset.updated_by = None + dataset.updated_at = None + dataset.summary_index_setting = summary_index_setting + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-123", + tenant_id: str = "tenant-123", + role: str = TenantAccountRole.OWNER, + **kwargs, + ) -> SimpleNamespace: + user = SimpleNamespace( + id=user_id, + current_tenant_id=tenant_id, + current_role=role, + ) + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + @staticmethod + def create_document_mock( + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + indexing_status: str = "completed", + is_paused: bool = False, + archived: bool = False, + enabled: bool = True, + data_source_type: str = "upload_file", + data_source_info_dict: dict | None = None, + data_source_info: str | None = None, + doc_form: str = "text_model", + need_summary: bool = True, + position: int = 0, + doc_metadata: dict | None = None, + name: str = "Document", + **kwargs, + ) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.indexing_status = indexing_status + document.is_paused = is_paused + document.paused_by = None + document.paused_at = None + document.archived = archived + document.enabled = enabled + document.data_source_type = data_source_type + document.data_source_info_dict = data_source_info_dict or {} + document.data_source_info = data_source_info + document.doc_form = doc_form + document.need_summary = need_summary + document.position = position + document.doc_metadata = doc_metadata + document.name = name + for key, value in kwargs.items(): + setattr(document, key, value) + return document + + @staticmethod + def create_upload_file_mock(file_id: str = "file-123", name: str = "upload.txt") -> Mock: + upload_file = Mock(spec=UploadFile) + upload_file.id = file_id + upload_file.name = name + return upload_file + + +_UNSET = object() + + +def _make_lock_context() -> MagicMock: + context_manager = MagicMock() + context_manager.__enter__.return_value = None + context_manager.__exit__.return_value = False + return context_manager + + +def _make_features(*, enabled: bool, plan: str = CloudPlan.PROFESSIONAL) -> SimpleNamespace: + return SimpleNamespace( + billing=SimpleNamespace( + enabled=enabled, + subscription=SimpleNamespace(plan=plan), + ), + documents_upload_quota=SimpleNamespace(limit=1000, size=0), + ) + + +def _make_dataset( + *, + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + data_source_type: str | None = None, + indexing_technique: str | None = "economy", + latest_process_rule=None, +) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.data_source_type = data_source_type + dataset.indexing_technique = indexing_technique + dataset.latest_process_rule = latest_process_rule + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + dataset.summary_index_setting = None + dataset.retrieval_model = None + dataset.collection_binding_id = None + return dataset + + +def _make_document( + *, + document_id: str = "doc-1", + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + batch: str = "batch-1", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + word_count: int = 0, + name: str = "Document 1", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + display_status: str = "available", +) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.batch = batch + document.doc_form = doc_form + document.word_count = word_count + document.name = name + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.display_status = display_status + document.data_source_type = "upload_file" + document.data_source_info = "{}" + document.completed_at = SimpleNamespace() + document.processing_started_at = "started" + document.parsing_completed_at = "parsed" + document.cleaning_completed_at = "cleaned" + document.splitting_completed_at = "split" + document.updated_at = None + document.created_from = None + document.dataset_process_rule_id = "process-rule-1" + return document + + +def _make_segment( + *, + segment_id: str = "segment-1", + content: str = "segment content", + word_count: int = 15, + enabled: bool = True, + keywords: list[str] | None = None, + index_node_id: str = "node-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", +) -> Mock: + segment = Mock(spec=DocumentSegment) + segment.id = segment_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.word_count = word_count + segment.enabled = enabled + segment.keywords = keywords or [] + segment.answer = None + segment.index_node_id = index_node_id + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.error = None + return segment + + +def _make_child_chunk() -> ChildChunk: + return ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + + +def _make_upload_knowledge_config( + *, + original_document_id: str | None = None, + file_ids: list[str] | None = None, + process_rule: ProcessRule | None = None, + data_source: DataSource | object | None = _UNSET, +) -> KnowledgeConfig: + if data_source is _UNSET: + info_list = InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=file_ids) if file_ids is not None else None, + ) + data_source = DataSource(info_list=info_list) + + return KnowledgeConfig( + original_document_id=original_document_id, + indexing_technique="economy", + data_source=data_source, + process_rule=process_rule, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + +def _make_retrieval_model( + *, + reranking_provider_name: str = "rerank-provider", + reranking_model_name: str = "rerank-model", +) -> RetrievalModel: + return RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name=reranking_provider_name, + reranking_model_name=reranking_model_name, + ), + reranking_mode="reranking_model", + top_k=4, + score_threshold_enabled=False, + ) + + +def _make_rag_pipeline_retrieval_setting() -> RagPipelineRetrievalSetting: + return RagPipelineRetrievalSetting( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + top_k=4, + score_threshold=0.5, + score_threshold_enabled=True, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RagPipelineRerankingModelConfig( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + ) + + +def _make_knowledge_configuration( + *, + chunk_structure: str = "paragraph", + indexing_technique: str = "high_quality", + embedding_model_provider: str = "provider", + embedding_model: str = "embedding-model", + keyword_number: int = 8, + summary_index_setting: dict | None = None, +) -> KnowledgeConfiguration: + return KnowledgeConfiguration( + chunk_structure=chunk_structure, + indexing_technique=indexing_technique, + embedding_model_provider=embedding_model_provider, + embedding_model=embedding_model, + keyword_number=keyword_number, + retrieval_model=_make_rag_pipeline_retrieval_setting(), + summary_index_setting=summary_index_setting, + ) diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e2..62c39f96d3 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory: name: str = "Test Dataset", description: str = "Test description", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str | None = "openai", embedding_model: str | None = "text-embedding-ada-002", collection_binding_id: str | None = "binding-123", @@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_knowledge_configuration_mock( chunk_structure: str = "tree", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", keyword_number: int = 10, @@ -591,7 +592,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: # Assert assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == "binding-123" @@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", # Existing structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", # Different structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) mock_session.merge.return_value = dataset @@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( dataset_id="dataset-123", runtime_mode="rag_pipeline", - indexing_technique="high_quality", # Current technique + indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique="economy", # Trying to change to economy + indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy ) mock_session.merge.return_value = dataset diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6829691507..3358c8b44d 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,7 +111,8 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -153,7 +154,7 @@ class DocumentValidationTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", **kwargs, @@ -188,8 +189,8 @@ class DocumentValidationTestDataFactory: def create_knowledge_config_mock( data_source: DataSource | None = None, process_rule: ProcessRule | None = None, - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, **kwargs, ) -> Mock: """ @@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm: - Validation logic works correctly """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "text_model" + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None) - doc_form = "text_model" + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm: - Error type is correct """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "table_model" # Different form + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card") - doc_form = "text_model" # Different form + doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -430,7 +431,7 @@ class TestDatasetServiceCheckDatasetModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_dataset_model_setting_high_quality_success(self, mock_model_manager): @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -480,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting: - No errors are raised """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) # Act (should not raise) DatasetService.check_dataset_model_setting(dataset) @@ -502,7 +503,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="invalid-model", ) @@ -532,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -579,7 +580,7 @@ class TestDatasetServiceCheckEmbeddingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_embedding_model_setting_success(self, mock_model_manager): @@ -701,7 +702,7 @@ class TestDatasetServiceCheckRerankingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_reranking_model_setting_success(self, mock_model_manager): diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py deleted file mode 100644 index b66111902c..0000000000 --- a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Unit tests for account deletion synchronization. - -This test module verifies the enterprise account deletion sync functionality, -including Redis queuing, error handling, and community vs enterprise behavior. -""" - -from unittest.mock import MagicMock, patch - -import pytest -from redis import RedisError - -from services.enterprise.account_deletion_sync import ( - _queue_task, - sync_account_deletion, - sync_workspace_member_removal, -) - - -class TestQueueTask: - """Unit tests for the _queue_task helper function.""" - - @pytest.fixture - def mock_redis_client(self): - """Mock redis_client for testing.""" - with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: - yield mock_redis - - @pytest.fixture - def mock_uuid(self): - """Mock UUID generation for predictable task IDs.""" - with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: - mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") - yield mock_uuid_gen - - def test_queue_task_success(self, mock_redis_client, mock_uuid): - """Test successful task queueing to Redis.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "test_source" - - # Act - result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_redis_client.lpush.assert_called_once() - - # Verify the task payload structure - call_args = mock_redis_client.lpush.call_args[0] - assert call_args[0] == "enterprise:member:sync:queue" - - import json - - task_data = json.loads(call_args[1]) - assert task_data["workspace_id"] == workspace_id - assert task_data["member_id"] == member_id - assert task_data["source"] == source - assert task_data["type"] == "sync_member_deletion_from_workspace" - assert task_data["retry_count"] == 0 - assert "task_id" in task_data - assert "created_at" in task_data - - def test_queue_task_redis_error(self, mock_redis_client, caplog): - """Test handling of Redis connection errors.""" - # Arrange - mock_redis_client.lpush.side_effect = RedisError("Connection failed") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - def test_queue_task_type_error(self, mock_redis_client, caplog): - """Test handling of JSON serialization errors.""" - # Arrange - mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - -class TestSyncWorkspaceMemberRemoval: - """Unit tests for sync_workspace_member_removal function.""" - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is True.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "workspace_member_removed" - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) - - def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): - """Test handling of queue task failures.""" - # Arrange - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - - -class TestSyncAccountDeletion: - """Unit tests for sync_account_deletion function.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: - yield mock_session - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is True - mock_db_session.query.assert_not_called() - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with multiple workspace memberships.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is True - assert mock_queue_task.call_count == 3 - - # Verify each workspace was queued - mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") - - def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with no workspace memberships.""" - # Arrange - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): - """Test sync when some tasks fail to queue.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - # Mock queue_task to fail for second workspace - def queue_side_effect(workspace_id, member_id, source): - return workspace_id != "tenant-2" - - mock_queue_task.side_effect = queue_side_effect - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is False # Should return False if any task fails - assert mock_queue_task.call_count == 3 - - def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): - """Test sync when all tasks fail to queue.""" - # Arrange - mock_join = MagicMock() - mock_join.tenant_id = "tenant-1" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join] - mock_db_session.query.return_value = mock_query - - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is False - mock_queue_task.assert_called_once() diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index afc3b29fca..a8ef35a0d0 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -545,7 +545,7 @@ class TestExternalDatasetServiceProcessExternalApi: params={}, ) - from dify_graph.nodes.http_request.exc import InvalidHttpMethodError + from graphon.nodes.http_request.exc import InvalidHttpMethodError with pytest.raises(InvalidHttpMethodError): ExternalDatasetService.process_external_api(settings, files=None) diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py index 27df4556bc..6511385000 100644 --- a/api/tests/unit_tests/services/plugin/test_oauth_service.py +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -13,6 +13,10 @@ import pytest from services.plugin.oauth_service import OAuthProxyService +def _oauth_proxy_setex_calls(redis_client) -> list: + return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")] + + class TestCreateProxyContext: def test_stores_context_in_redis_with_ttl(self): context_id = OAuthProxyService.create_proxy_context( @@ -22,8 +26,9 @@ class TestCreateProxyContext: assert context_id # non-empty UUID string from extensions.ext_redis import redis_client - redis_client.setex.assert_called_once() - call_args = redis_client.setex.call_args + oauth_calls = _oauth_proxy_setex_calls(redis_client) + assert len(oauth_calls) == 1 + call_args = oauth_calls[0] key = call_args[0][0] ttl = call_args[0][1] stored_data = json.loads(call_args[0][2]) diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py deleted file mode 100644 index 5d21665f75..0000000000 --- a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py +++ /dev/null @@ -1,145 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval -from services.recommend_app.recommend_app_type import RecommendAppType - - -class TestDatabaseRecommendAppRetrieval: - def test_get_type(self): - assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE - - def test_get_recommended_apps_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_apps_from_db", - return_value={"recommended_apps": [], "categories": []}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") - mock_fetch.assert_called_once_with("en-US") - assert result == {"recommended_apps": [], "categories": []} - - def test_get_recommend_app_detail_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_app_detail_from_db", - return_value={"id": "app-1"}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") - mock_fetch.assert_called_once_with("app-1") - assert result == {"id": "app-1"} - - -class TestFetchRecommendedAppsFromDb: - def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): - site = ( - SimpleNamespace( - description="desc", - copyright="copy", - privacy_policy="pp", - custom_disclaimer="cd", - ) - if has_site - else None - ) - app = ( - SimpleNamespace(is_public=is_public, site=site) - if is_public - else SimpleNamespace(is_public=False, site=site) - ) - return SimpleNamespace( - id=f"rec-{app_id}", - app=app, - app_id=app_id, - category=category, - position=1, - is_listed=True, - ) - - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_apps_and_sorted_categories(self, mock_db): - rec1 = self._make_recommended_app("a1", "writing") - rec2 = self._make_recommended_app("a2", "assistant") - mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert len(result["recommended_apps"]) == 2 - assert result["categories"] == ["assistant", "writing"] - - @patch("services.recommend_app.database.database_retrieval.db") - def test_falls_back_to_default_language_when_empty(self, mock_db): - mock_db.session.scalars.return_value.all.side_effect = [ - [], - [self._make_recommended_app("a1", "chat")], - ] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") - - assert len(result["recommended_apps"]) == 1 - assert mock_db.session.scalars.call_count == 2 - - @patch("services.recommend_app.database.database_retrieval.db") - def test_skips_non_public_apps(self, mock_db): - rec = self._make_recommended_app("a1", "chat", is_public=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - @patch("services.recommend_app.database.database_retrieval.db") - def test_skips_apps_without_site(self, mock_db): - rec = self._make_recommended_app("a1", "chat", has_site=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - -class TestFetchRecommendedAppDetailFromDb: - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_not_listed(self, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) - mock_db.session.query.side_effect = [rec_chain, app_chain] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_detail_on_success(self, mock_db, mock_dsl): - app_model = SimpleNamespace( - id="app-1", - name="My App", - icon="icon.png", - icon_background="#fff", - mode="chat", - is_public=True, - ) - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = app_model - mock_db.session.query.side_effect = [rec_chain, app_chain] - mock_dsl.export_dsl.return_value = "exported_yaml" - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result["id"] == "app-1" - assert result["name"] == "My App" - assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py deleted file mode 100644 index 9fe153c153..0000000000 --- a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,216 +0,0 @@ -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy.orm import Session - -from models.workflow import WorkflowRun -from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult - - -class TestArchivedWorkflowRunDeletion: - @pytest.fixture - def mock_db(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - @pytest.fixture - def mock_sessionmaker(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - mock_session = MagicMock(spec=Session) - mock_sm.return_value.return_value.__enter__.return_value = mock_session - yield mock_sm, mock_session - - @pytest.fixture - def mock_workflow_run_repo(self): - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - yield mock_repo - - def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - tenant_id = "tenant-456" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_run.tenant_id = tenant_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [run_id] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) - mock_delete_run.return_value = expected_result - - result = deletion.delete_by_run_id(run_id) - - assert result == expected_result - mock_session.get.assert_called_once_with(WorkflowRun, run_id) - mock_repo.get_archived_run_ids.assert_called_once() - mock_delete_run.assert_called_once_with(mock_run) - - def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - mock_session.get.return_value = None - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo"): - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "not found" in result.error - assert result.run_id == run_id - - def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [] - - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "is not archived" in result.error - - def test_delete_batch(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - deletion = ArchivedWorkflowRunDeletion() - - mock_run1 = MagicMock(spec=WorkflowRun) - mock_run1.id = "run-1" - mock_run2 = MagicMock(spec=WorkflowRun) - mock_run2.id = "run-2" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - mock_delete_run.side_effect = [ - DeleteResult(run_id="run-1", tenant_id="t1", success=True), - DeleteResult(run_id="run-2", tenant_id="t1", success=True), - ] - - results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) - - assert len(results) == 2 - assert results[0].run_id == "run-1" - assert results[1].run_id == "run-2" - assert mock_delete_run.call_count == 2 - - def test_delete_run_dry_run(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=True) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.run_id == "run-123" - - def test_delete_run_success(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.deleted_counts == {"workflow_runs": 1} - - def test_delete_run_exception(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.side_effect = Exception("Database error") - - result = deletion._delete_run(mock_run) - - assert result.success is False - assert result.error == "Database error" - - def test_delete_trigger_logs(self): - mock_session = MagicMock(spec=Session) - run_ids = ["run-1", "run-2"] - - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - mock_repo_cls.return_value = mock_repo - mock_repo.delete_by_run_ids.return_value = 5 - - count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) - - assert count == 5 - mock_repo_cls.assert_called_once_with(mock_session) - mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) - - def test_delete_node_executions(self): - mock_session = MagicMock(spec=Session) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-1" - runs = [mock_run] - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - mock_repo.delete_by_runs.return_value = (1, 2) - - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) - - assert result == (1, 2) - mock_create_repo.assert_called_once() - mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) - - def test_get_workflow_run_repo(self, mock_db): - deletion = ArchivedWorkflowRunDeletion() - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - - # First call - repo1 = deletion._get_workflow_run_repo() - assert repo1 == mock_repo - assert deletion.workflow_run_repo == mock_repo - - # Second call (should return cached) - repo2 = deletion._get_workflow_run_repo() - assert repo2 == mock_repo - mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..f0a66a00d4 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,8 +2,10 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +79,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None @@ -90,7 +92,7 @@ class SegmentTestDataFactory: document_id: str = "doc-123", dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, word_count: int = 100, **kwargs, ) -> Mock: @@ -109,7 +111,7 @@ class SegmentTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model: str = "text-embedding-ada-002", embedding_model_provider: str = "openai", **kwargs, @@ -161,7 +163,7 @@ class TestSegmentServiceCreateSegment: """Test successful creation of a segment.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test", "segment"]} mock_query = MagicMock() @@ -209,8 +211,8 @@ class TestSegmentServiceCreateSegment: def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): """Test creation of segment with QA model (requires answer).""" # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} mock_query = MagicMock() @@ -245,7 +247,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with high quality indexing technique.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -265,7 +267,7 @@ class TestSegmentServiceCreateSegment: patch( "services.dataset_service.VectorService.create_segments_vector", autospec=True ) as mock_vector_service, - patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class, + patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): @@ -287,7 +289,7 @@ class TestSegmentServiceCreateSegment: """Test segment creation when vector indexing fails.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -340,7 +342,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment @@ -428,8 +430,8 @@ class TestSegmentServiceUpdateSegment: """Test update segment with QA model (includes answer).""" # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py deleted file mode 100644 index a6bc79e82b..0000000000 --- a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Unit tests for services.advanced_prompt_template_service -""" - -import copy - -from core.prompt.prompt_templates.advanced_prompt_templates import ( - BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, - CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - CONTEXT, -) -from models.model import AppMode -from services.advanced_prompt_template_service import AdvancedPromptTemplateService - - -class TestAdvancedPromptTemplateService: - """Test suite for AdvancedPromptTemplateService.""" - - def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: - """Test baichuan model names use baichuan context prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "chat", - "model_name": "Baichuan2-13B", - "has_context": "true", - } - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - - def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: - """Test non-baichuan model names use common prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "completion", - "model_name": "gpt-4", - "has_context": "false", - } - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result == original_config - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: - """Test invalid app mode returns empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "chat" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} - - def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: - """Test context is prepended for completion prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: - """Test context is prepended for chat prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) - assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: - """Test chat prompt remains unchanged when has_context is false.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") - - # Assert - assert result == original_config - assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: - """Test completion app mode with completion model returns completion prompt.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") - - # Assert - assert result == original_config - assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: - """Test invalid model mode returns empty dict.""" - # Arrange - app_mode = AppMode.CHAT - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") - - # Assert - assert result == {} - - def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps completion prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] - - # Act - result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"] == original_text - - def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps chat prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] - - # Act - result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text - - def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: - """Test baichuan chat/completion returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: - """Test baichuan completion/chat returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: - """Test baichuan completion/completion prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: - """Test baichuan chat/chat prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: - """Test invalid baichuan mode combinations return empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py deleted file mode 100644 index 7ce3d7ef7b..0000000000 --- a/api/tests/unit_tests/services/test_agent_service.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -Unit tests for services.agent_service -""" - -from collections.abc import Callable -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -import pytz - -from core.plugin.impl.exc import PluginDaemonClientSideError -from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought -from services.agent_service import AgentService - - -def _make_current_user_account(timezone: str = "UTC") -> Account: - account = Account(name="Test User", email="test@example.com") - account.timezone = timezone - return account - - -def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: - app_model = MagicMock(spec=App) - app_model.id = "app-123" - app_model.tenant_id = "tenant-123" - app_model.app_model_config = app_model_config - return app_model - - -def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: - conversation = MagicMock(spec=Conversation) - conversation.id = "conv-123" - conversation.app_id = "app-123" - conversation.from_end_user_id = from_end_user_id - conversation.from_account_id = from_account_id - return conversation - - -def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: - message = MagicMock(spec=Message) - message.id = "msg-123" - message.conversation_id = "conv-123" - message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - message.provider_response_latency = 1.23 - message.answer_tokens = 4 - message.message_tokens = 6 - message.agent_thoughts = agent_thoughts - message.message_files = ["file-a.txt"] - return message - - -def _make_agent_thought() -> MagicMock: - agent_thought = MagicMock(spec=MessageAgentThought) - agent_thought.tokens = 3 - agent_thought.tool_input = "raw-input" - agent_thought.observation = "raw-output" - agent_thought.thought = "thinking" - agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - agent_thought.files = [] - agent_thought.tools = ["tool_a", "dataset_tool"] - agent_thought.tool_labels = {"tool_a": "Tool A"} - agent_thought.tool_meta = { - "tool_a": { - "tool_config": { - "tool_provider_type": "custom", - "tool_provider": "provider-1", - }, - "tool_parameters": {"param": "value"}, - "time_cost": 2.5, - }, - "dataset_tool": { - "tool_config": { - "tool_provider_type": "dataset-retrieval", - "tool_provider": "dataset-provider", - } - }, - } - agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} - agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} - return agent_thought - - -def _build_query_side_effect( - conversation: Conversation | None, - message: Message | None, - executor: EndUser | Account | None, -) -> Callable[..., MagicMock]: - def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: - query = MagicMock() - query.where.return_value = query - if any(arg is Conversation for arg in args): - query.first.return_value = conversation - elif any(arg is Message for arg in args): - query.first.return_value = message - elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): - query.first.return_value = executor - return query - - return _query_side_effect - - -class TestAgentServiceGetAgentLogs: - """Test suite for AgentService.get_agent_logs.""" - - def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: - """Test missing conversation raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - with patch("services.agent_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") - - def test_get_agent_logs_should_raise_when_message_missing(self) -> None: - """Test missing message raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - with patch("services.agent_service.db") as mock_db: - conversation_query = MagicMock() - conversation_query.where.return_value = conversation_query - conversation_query.first.return_value = conversation - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = None - - mock_db.session.query.side_effect = [conversation_query, message_query] - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") - - def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: - """Test missing app model config raises ValueError.""" - # Arrange - app_model = _make_app_model(None) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: - """Test missing agent config raises ValueError.""" - # Arrange - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: - """Test agent logs returned for end-user executor with tool icons.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - executor = MagicMock(spec=EndUser) - executor.name = "End User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_tool = MagicMock() - agent_tool.tool_name = "tool_a" - agent_tool.provider_type = "custom" - agent_tool.provider_id = "provider-2" - agent_config = MagicMock() - agent_config.tools = [agent_tool] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, - patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - mock_get_icon.side_effect = [None, "icon-a"] - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["status"] == "success" - assert result["meta"]["executor"] == "End User" - assert result["meta"]["total_tokens"] == 10 - assert result["meta"]["agent_mode"] == "react" - assert result["meta"]["iterations"] == 1 - assert result["files"] == ["file-a.txt"] - assert len(result["iterations"]) == 1 - tool_calls = result["iterations"][0]["tool_calls"] - assert tool_calls[0]["tool_name"] == "tool_a" - assert tool_calls[0]["tool_icon"] == "icon-a" - assert tool_calls[1]["tool_name"] == "dataset_tool" - assert tool_calls[1]["tool_icon"] == "" - mock_convert.assert_called_once() - - def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: - """Test agent logs fall back to account executor when end user is missing.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") - executor = MagicMock(spec=Account) - executor.name = "Account User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Account User" - - def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: - """Test unknown executor and missing tool details fall back to defaults.""" - # Arrange - agent_thought = _make_agent_thought() - agent_thought.tool_labels = {} - agent_thought.tool_inputs_dict = {} - agent_thought.tool_outputs_dict = None - agent_thought.tool_meta = {"tool_a": {"error": "failed"}} - agent_thought.tools = ["tool_a"] - - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Unknown" - assert result["meta"]["agent_mode"] == "react" - tool_call = result["iterations"][0]["tool_calls"][0] - assert tool_call["status"] == "error" - assert tool_call["error"] == "failed" - assert tool_call["tool_label"] == "tool_a" - assert tool_call["tool_input"] == {} - assert tool_call["tool_output"] == {} - assert tool_call["time_cost"] == 0 - assert tool_call["tool_parameters"] == {} - assert tool_call["tool_icon"] is None - - -class TestAgentServiceProviders: - """Test suite for AgentService provider methods.""" - - def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: - """Test list_agent_providers delegates to PluginAgentClient.""" - # Arrange - tenant_id = "tenant-1" - expected = [{"name": "provider"}] - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_providers.return_value = expected - - # Act - result = AgentService.list_agent_providers("user-1", tenant_id) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) - - def test_get_agent_provider_should_return_provider_when_successful(self) -> None: - """Test get_agent_provider returns provider when successful.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - expected = {"name": provider_name} - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.return_value = expected - - # Act - result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) - - def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: - """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( - "plugin error" - ) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py deleted file mode 100644 index 7f4b5fdaa3..0000000000 --- a/api/tests/unit_tests/services/test_api_based_extension_service.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Comprehensive unit tests for services/api_based_extension_service.py - -Covers: -- APIBasedExtensionService.get_all_by_tenant_id -- APIBasedExtensionService.save -- APIBasedExtensionService.delete -- APIBasedExtensionService.get_with_tenant_id -- APIBasedExtensionService._validation (new record & existing record branches) -- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from services.api_based_extension_service import APIBasedExtensionService - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_extension( - *, - id_: str | None = None, - tenant_id: str = "tenant-001", - name: str = "my-ext", - api_endpoint: str = "https://example.com/hook", - api_key: str = "secret-key-123", -) -> MagicMock: - """Return a lightweight mock that mimics APIBasedExtension.""" - ext = MagicMock() - ext.id = id_ - ext.tenant_id = tenant_id - ext.name = name - ext.api_endpoint = api_endpoint - ext.api_key = api_key - return ext - - -# --------------------------------------------------------------------------- -# Tests: get_all_by_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetAllByTenantId: - """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): - """Each api_key is decrypted and the list is returned.""" - ext1 = _make_extension(id_="id-1", api_key="enc-key-1") - ext2 = _make_extension(id_="id-2", api_key="enc-key-2") - - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ - ext1, - ext2, - ] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [ext1, ext2] - assert ext1.api_key == "decrypted-key" - assert ext2.api_key == "decrypted-key" - assert mock_decrypt.call_count == 2 - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): - """Returns an empty list gracefully when no records exist.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [] - mock_decrypt.assert_not_called() - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): - """Verifies the DB is queried with the supplied tenant_id.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") - - mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") - - -# --------------------------------------------------------------------------- -# Tests: save -# --------------------------------------------------------------------------- - - -class TestSave: - """Tests for APIBasedExtensionService.save.""" - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation") - def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): - """Happy path: validation passes, key is encrypted, record is added and committed.""" - ext = _make_extension(id_=None, api_key="plain-key-123") - - result = APIBasedExtensionService.save(ext) - - mock_validation.assert_called_once_with(ext) - mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") - assert ext.api_key == "encrypted-key" - mock_db.session.add.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - assert result is ext - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) - def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): - """If _validation raises, save should propagate the error without touching the DB.""" - ext = _make_extension(name="") - - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(ext) - - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: delete -# --------------------------------------------------------------------------- - - -class TestDelete: - """Tests for APIBasedExtensionService.delete.""" - - @patch("services.api_based_extension_service.db") - def test_delete_removes_record_and_commits(self, mock_db): - """delete() must call session.delete with the extension and then commit.""" - ext = _make_extension(id_="delete-me") - - APIBasedExtensionService.delete(ext) - - mock_db.session.delete.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests: get_with_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetWithTenantId: - """Tests for APIBasedExtensionService.get_with_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): - """Found extension has its api_key decrypted before being returned.""" - ext = _make_extension(id_="ext-123", api_key="enc-key") - - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext - - result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") - - assert result is ext - assert ext.api_key == "decrypted-key" - mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") - - @patch("services.api_based_extension_service.db") - def test_raises_value_error_when_not_found(self, mock_db): - """Raises ValueError when no matching extension exists.""" - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None - - with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): - """Verifies both tenant_id and extension id are used in the query.""" - ext = _make_extension(id_="ext-abc") - chain = mock_db.session.query.return_value - chain.filter_by.return_value.filter_by.return_value.first.return_value = ext - - APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") - - # First filter_by call uses tenant_id - chain.filter_by.assert_called_once_with(tenant_id="tenant-002") - # Second filter_by call uses id - chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") - - -# --------------------------------------------------------------------------- -# Tests: _validation (new record — id is falsy) -# --------------------------------------------------------------------------- - - -class TestValidationNewRecord: - """Tests for _validation() with a brand-new record (no id).""" - - def _build_mock_db(self, name_exists: bool = False): - mock_db = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() if name_exists else None - ) - return mock_db - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_new_extension_passes(self, mock_db, mock_ping): - """A new record with all valid fields should pass without exceptions.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_empty(self, mock_db): - """Empty name raises ValueError.""" - ext = _make_extension(id_=None, name="") - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_none(self, mock_db): - """None name raises ValueError.""" - ext = _make_extension(id_=None, name=None) - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_already_exists_for_new_record(self, mock_db): - """A new record whose name already exists raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() - ) - ext = _make_extension(id_=None, name="duplicate-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_empty(self, mock_db): - """Empty api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint="") - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_none(self, mock_db): - """None api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint=None) - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_empty(self, mock_db): - """Empty api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="") - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_none(self, mock_db): - """None api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key=None) - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_too_short(self, mock_db): - """api_key shorter than 5 characters raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="abc") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_exactly_four_chars(self, mock_db): - """api_key with exactly 4 characters raises ValueError (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="1234") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): - """api_key with exactly 5 characters should pass (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="12345") - - # Should not raise - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _validation (existing record — id is truthy) -# --------------------------------------------------------------------------- - - -class TestValidationExistingRecord: - """Tests for _validation() with an existing record (id is set).""" - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_existing_extension_passes(self, mock_db, mock_ping): - """An existing record whose name is unique (excluding self) should pass.""" - # .where(...).first() → None means no *other* record has that name - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = None - ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): - """Existing record cannot use a name already owned by a different record.""" - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = MagicMock() - ext = _make_extension(id_="existing-id", name="taken-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _ping_connection -# --------------------------------------------------------------------------- - - -class TestPingConnection: - """Tests for APIBasedExtensionService._ping_connection.""" - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_successful_ping_returns_pong(self, mock_requestor_class): - """When the endpoint returns {"result": "pong"}, no exception is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") - # Should not raise - APIBasedExtensionService._ping_connection(ext) - - mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): - """When the response is not {"result": "pong"}, a ValueError is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "error"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_network_exception_wraps_in_value_error(self, mock_requestor_class): - """Any exception raised during request is wrapped in a ValueError.""" - mock_client = MagicMock() - mock_client.request.side_effect = ConnectionError("network failure") - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): - """Exception raised by the requestor constructor itself is wrapped.""" - mock_requestor_class.side_effect = RuntimeError("bad config") - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_missing_result_key_raises_value_error(self, mock_requestor_class): - """A response dict without a 'result' key does not equal 'pong' → raises.""" - mock_client = MagicMock() - mock_client.request.return_value = {} # no 'result' key - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_uses_ping_extension_point(self, mock_requestor_class): - """The PING extension point is passed to the client.request call.""" - from models.api_based_extension import APIBasedExtensionPoint - - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - APIBasedExtensionService._ping_connection(ext) - - call_kwargs = mock_client.request.call_args - assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING - assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 4f7d184046..afea8ec92a 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -10,7 +10,7 @@ from core.trigger.constants import ( TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service @@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch def test_import_app_pending_stores_import_info_in_redis(): service = AppDslService(MagicMock()) + app_dsl_service.redis_client.setex.reset_mock() result = service.import_app( account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, @@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch): created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + app_dsl_service.redis_client.delete.reset_mock() result = service.confirm_import(import_id="import-1", account=_account_mock()) assert result.status == ImportStatus.COMPLETED assert result.app_id == "confirmed-app" - app_dsl_service.redis_client.delete.assert_called_once() + app_dsl_service.redis_client.delete.assert_called_once_with( + f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1" + ) def test_confirm_import_exception_returns_failed(monkeypatch): diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py deleted file mode 100644 index bff8dc92c6..0000000000 --- a/api/tests/unit_tests/services/test_app_service.py +++ /dev/null @@ -1,609 +0,0 @@ -"""Unit tests for services.app_service.""" - -import json -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock, patch - -import pytest - -from core.errors.error import ProviderTokenNotInitError -from models import Account, Tenant -from models.model import App, AppMode -from services.app_service import AppService - - -@pytest.fixture -def service() -> AppService: - """Provide AppService instance.""" - return AppService() - - -@pytest.fixture -def account() -> Account: - """Create account object for create_app tests.""" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - result = Account(name="Account User", email="account@example.com") - result.id = "acc-1" - result._current_tenant = tenant - return result - - -@pytest.fixture -def default_args() -> dict: - """Create default create_app args.""" - return { - "name": "Test App", - "mode": AppMode.CHAT.value, - "icon": "🤖", - "icon_background": "#FFFFFF", - } - - -@pytest.fixture -def app_template() -> dict: - """Create basic app template for create_app tests.""" - return { - AppMode.CHAT: { - "app": {}, - "model_config": { - "model": { - "provider": "provider-a", - "name": "model-a", - "mode": "chat", - "completion_params": {}, - } - }, - } - } - - -def _make_current_user() -> Account: - user = Account(name="Tester", email="tester@example.com") - user.id = "user-1" - tenant = Tenant(name="Tenant") - tenant.id = "tenant-1" - user._current_tenant = tenant - return user - - -class TestAppServicePagination: - """Test suite for get_paginate_apps.""" - - def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: - """Test pagination returns None when tag filter has no targets.""" - # Arrange - args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} - - with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is None - - def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: - """Test pagination delegates to db.paginate when filters are valid.""" - # Arrange - args = { - "mode": "workflow", - "is_created_by_me": True, - "name": "My_App%", - "tag_ids": ["tag-1"], - "page": 2, - "limit": 10, - } - expected_pagination = MagicMock() - - with ( - patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), - patch("libs.helper.escape_like_pattern", return_value="escaped"), - patch("services.app_service.db") as mock_db, - ): - mock_db.paginate.return_value = expected_pagination - - # Act - result = service.get_paginate_apps("user-1", "tenant-1", args) - - # Assert - assert result is expected_pagination - mock_db.paginate.assert_called_once() - - -class TestAppServiceCreate: - """Test suite for create_app.""" - - def test_create_app_should_create_with_matching_default_model( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app uses matching default model and persists app config.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - model_instance = SimpleNamespace( - model_name="model-a", - provider="provider-a", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) - mock_config.BILLING_ENABLED = True - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - assert app_instance.app_model_config_id == "cfg-1" - mock_db.session.add.assert_any_call(app_instance) - mock_db.session.add.assert_any_call(app_model_config) - assert mock_db.session.flush.call_count == 2 - mock_db.session.commit.assert_called_once() - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - - def test_create_app_should_raise_when_model_schema_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app raises ValueError when non-matching model has no schema.""" - # Arrange - app_instance = SimpleNamespace(id="app-1") - model_instance = SimpleNamespace( - model_name="model-b", - provider="provider-b", - model_type_instance=MagicMock(), - credentials={"k": "v"}, - ) - model_instance.model_type_instance.get_model_schema.return_value = None - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.return_value = model_instance - - # Act & Assert - with pytest.raises(ValueError, match="model schema not found"): - service.create_app("tenant-1", default_args, account) - mock_db.session.commit.assert_not_called() - - def test_create_app_should_fallback_to_default_provider_when_model_missing( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test create_app falls back to provider/model name when no default model instance is available.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db") as mock_db, - patch("services.app_service.app_was_created") as mock_event, - patch("services.app_service.FeatureService.get_system_features") as mock_features, - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch("services.app_service.dify_config") as mock_config, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) - mock_config.BILLING_ENABLED = False - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_event.send.assert_called_once_with(app_instance, account=account) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") - - def test_create_app_should_log_and_fallback_on_unexpected_model_error( - self, - service: AppService, - account: Account, - default_args: dict, - app_template: dict, - ) -> None: - """Test unexpected model manager errors are logged and fallback provider is used.""" - # Arrange - app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") - app_model_config = SimpleNamespace(id="cfg-1") - - with ( - patch("services.app_service.default_app_templates", app_template), - patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.app_service.db"), - patch("services.app_service.app_was_created"), - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), - ), - patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), - patch("services.app_service.logger") as mock_logger, - ): - manager = mock_model_manager.return_value - manager.get_default_model_instance.side_effect = RuntimeError("boom") - manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") - - # Act - result = service.create_app("tenant-1", default_args, account) - - # Assert - assert result is app_instance - mock_logger.exception.assert_called_once() - - -class TestAppServiceGetAndUpdate: - """Test suite for app retrieval and update methods.""" - - def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: - """Test get_app returns original app for non-agent modes.""" - # Arrange - app = MagicMock() - app.mode = AppMode.CHAT - app.is_agent = False - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: - """Test get_app returns app when agent mode has no model config.""" - # Arrange - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = None - - with patch("services.app_service.current_user", _make_current_user()): - # Act - result = service.get_app(app) - - # Assert - assert result is app - - def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: - """Test get_app decrypts and masks secret tool parameters.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - manager = MagicMock() - manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} - manager.mask_tool_parameters.return_value = {"secret": "***"} - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), - patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - assert tool["tool_parameters"] == {"secret": "***"} - assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} - - def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: - """Test get_app logs and continues when masking fails.""" - # Arrange - tool = { - "provider_type": "builtin", - "provider_id": "provider-1", - "tool_name": "tool-a", - "tool_parameters": {"secret": "encrypted"}, - "extra": True, - } - model_config = MagicMock() - model_config.agent_mode_dict = {"tools": [tool]} - - app = MagicMock() - app.id = "app-1" - app.mode = AppMode.AGENT_CHAT - app.is_agent = False - app.app_model_config = model_config - - with ( - patch("services.app_service.current_user", _make_current_user()), - patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), - patch("services.app_service.logger") as mock_logger, - ): - # Act - result = service.get_app(app) - - # Assert - assert result.app_model_config is model_config - mock_logger.exception.assert_called_once() - - def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: - """Test update methods set fields and commit changes.""" - # Arrange - app = cast( - App, - SimpleNamespace( - name="old", - description="old", - icon_type="emoji", - icon="a", - icon_background="#111", - enable_site=True, - enable_api=True, - ), - ) - args = { - "name": "new", - "description": "new-desc", - "icon_type": "image", - "icon": "new-icon", - "icon_background": "#222", - "use_icon_as_answer_icon": True, - "max_active_requests": 5, - } - user = SimpleNamespace(id="user-1") - - with ( - patch("services.app_service.current_user", user), - patch("services.app_service.db") as mock_db, - patch("services.app_service.naive_utc_now", return_value="now"), - ): - # Act - updated = service.update_app(app, args) - renamed = service.update_app_name(app, "rename") - iconed = service.update_app_icon(app, "icon-2", "#333") - site_same = service.update_app_site_status(app, app.enable_site) - api_same = service.update_app_api_status(app, app.enable_api) - site_changed = service.update_app_site_status(app, False) - api_changed = service.update_app_api_status(app, False) - - # Assert - assert updated is app - assert renamed is app - assert iconed is app - assert site_same is app - assert api_same is app - assert site_changed is app - assert api_changed is app - assert mock_db.session.commit.call_count >= 5 - - -class TestAppServiceDeleteAndMeta: - """Test suite for delete and metadata methods.""" - - def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: - """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" - # Arrange - app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - with ( - patch("services.app_service.db") as mock_db, - patch( - "services.app_service.FeatureService.get_system_features", - return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), - ), - patch("services.app_service.EnterpriseService") as mock_enterprise, - patch( - "services.app_service.dify_config", - new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), - ), - patch("services.app_service.BillingService") as mock_billing, - patch("services.app_service.remove_app_and_related_data_task") as mock_task, - ): - # Act - service.delete_app(app) - - # Assert - mock_db.session.delete.assert_called_once_with(app) - mock_db.session.commit.assert_called_once() - mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") - mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") - mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") - - def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: - """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" - # Arrange - workflow = SimpleNamespace( - graph_dict={ - "nodes": [ - { - "data": { - "type": "tool", - "provider_type": "builtin", - "provider_id": "builtin-provider", - "tool_name": "tool_builtin", - } - }, - { - "data": { - "type": "tool", - "provider_type": "api", - "provider_id": "api-provider-id", - "tool_name": "tool_api", - } - }, - ] - } - ) - app = cast( - App, - SimpleNamespace( - mode=AppMode.WORKFLOW.value, - workflow=workflow, - app_model_config=None, - tenant_id="tenant-1", - icon_type="emoji", - icon_background="#fff", - ), - ) - - provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = provider - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") - assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} - - def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: - """Test get_app_meta falls back to default icon when API provider lookup fails.""" - # Arrange - app_model_config = SimpleNamespace( - agent_mode_dict={ - "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] - } - ) - app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) - - with ( - patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), - patch("services.app_service.db") as mock_db, - ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act - meta = service.get_app_meta(app) - - # Assert - assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} - - def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: - """Test get_app_meta returns empty metadata when workflow/model config is absent.""" - # Arrange - workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) - chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) - - # Act - workflow_meta = service.get_app_meta(workflow_app) - chat_meta = service.get_app_meta(chat_app) - - # Assert - assert workflow_meta == {"tool_icons": {}} - assert chat_meta == {"tool_icons": {}} - - -class TestAppServiceCodeLookup: - """Test suite for app code lookup methods.""" - - def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: - """Test get_app_code_by_id raises when site is missing.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_code_by_id("app-1") - - def test_get_app_code_by_id_should_return_code(self) -> None: - """Test get_app_code_by_id returns site code.""" - # Arrange - site = SimpleNamespace(code="code-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_code_by_id("app-1") - - # Assert - assert result == "code-1" - - def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: - """Test get_app_id_by_code raises when code does not exist.""" - # Arrange - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError, match="not found"): - AppService.get_app_id_by_code("missing") - - def test_get_app_id_by_code_should_return_app_id(self) -> None: - """Test get_app_id_by_code returns linked app id.""" - # Arrange - site = SimpleNamespace(app_id="app-1") - with patch("services.app_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = site - mock_db.session.query.return_value = query - - # Act - result = AppService.get_app_id_by_code("code-1") - - # Assert - assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_attachment_service.py b/api/tests/unit_tests/services/test_attachment_service.py deleted file mode 100644 index 88be20bc41..0000000000 --- a/api/tests/unit_tests/services/test_attachment_service.py +++ /dev/null @@ -1,73 +0,0 @@ -import base64 -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from werkzeug.exceptions import NotFound - -import services.attachment_service as attachment_service_module -from models.model import UploadFile -from services.attachment_service import AttachmentService - - -class TestAttachmentService: - def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self): - """Test that AttachmentService keeps the provided sessionmaker instance.""" - session_factory = sessionmaker() - - service = AttachmentService(session_factory=session_factory) - - assert service._session_maker is session_factory - - def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self): - """Test that AttachmentService builds a sessionmaker bound to the provided engine.""" - engine = create_engine("sqlite:///:memory:") - - service = AttachmentService(session_factory=engine) - session = service._session_maker() - try: - assert session.bind == engine - finally: - session.close() - engine.dispose() - - @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) - def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory): - """Test that invalid session_factory types are rejected.""" - with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): - AttachmentService(session_factory=invalid_session_factory) - - def test_should_return_base64_encoded_blob_when_file_exists(self): - """Test that existing files are loaded from storage and returned as base64.""" - service = AttachmentService(session_factory=sessionmaker()) - upload_file = MagicMock(spec=UploadFile) - upload_file.key = "upload-file-key" - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = upload_file - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: - result = service.get_file_base64("file-123") - - assert result == base64.b64encode(b"binary-content").decode() - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_called_once_with("upload-file-key") - - def test_should_raise_not_found_when_file_does_not_exist(self): - """Test that missing files raise NotFound and never call storage.""" - service = AttachmentService(session_factory=sessionmaker()) - - session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None - service._session_maker = MagicMock(return_value=session) - - with patch.object(attachment_service_module.storage, "load_once") as mock_load: - with pytest.raises(NotFound, match="File not found"): - service.get_file_base64("missing-file") - - service._session_maker.assert_called_once_with(expire_on_commit=False) - session.query.assert_called_once_with(UploadFile) - mock_load.assert_not_called() diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d67469105..175fd3ee01 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -214,7 +214,7 @@ def factory(): class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in CHAT mode.""" # Arrange @@ -237,10 +237,9 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Transcribed text"} mock_model_instance.invoke_speech2text.assert_called_once() - call_args = mock_model_instance.invoke_speech2text.call_args - assert call_args.kwargs["user"] == "user-123" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange @@ -347,7 +346,7 @@ class TestAudioServiceASR: with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"): AudioService.transcript_asr(app_model=app, file=file) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that ASR raises error when no model instance is available.""" # Arrange @@ -370,7 +369,7 @@ class TestAudioServiceASR: class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): """Test successful TTS with text input.""" # Arrange @@ -398,15 +397,14 @@ class TestAudioServiceTTS: # Assert assert result == b"audio data" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") mock_model_instance.invoke_tts.assert_called_once_with( content_text="Hello world", - user="user-123", - tenant_id=app.tenant_id, voice="en-US-Neural", ) @patch("services.audio_service.db.session", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" # Arrange @@ -445,7 +443,7 @@ class TestAudioServiceTTS: assert result == b"audio from message" mock_model_instance.invoke_tts.assert_called_once() - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): """Test TTS uses default voice when none specified.""" # Arrange @@ -475,7 +473,7 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "default-voice" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): """Test TTS gets first available voice when none is configured.""" # Arrange @@ -506,7 +504,7 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "auto-voice" @patch("services.audio_service.WorkflowService", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( self, mock_model_manager_class, mock_workflow_service_class, factory ): @@ -549,7 +547,7 @@ class TestAudioServiceTTS: with pytest.raises(ValueError, match="Text is required"): AudioService.transcript_tts(app_model=app, text=None) - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): """Test that TTS returns None for invalid message ID format.""" # Arrange @@ -564,7 +562,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): """Test that TTS returns None when message doesn't exist.""" # Arrange @@ -585,7 +583,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.db.session") def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): """Test that TTS returns None when message answer is empty.""" # Arrange @@ -611,7 +609,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): """Test that TTS raises error when no voices are available.""" # Arrange @@ -637,7 +635,7 @@ class TestAudioServiceTTS: class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): """Test successful retrieval of TTS voices.""" # Arrange @@ -662,7 +660,7 @@ class TestAudioServiceTTSVoices: assert result == expected_voices mock_model_instance.get_tts_voices.assert_called_once_with(language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that TTS voices raises error when no model instance is available.""" # Arrange @@ -677,7 +675,7 @@ class TestAudioServiceTTSVoices: with pytest.raises(ProviderNotSupportTextToSpeechServiceError): AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): """Test that TTS voices propagates exceptions from model instance.""" # Arrange diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 135c2e9962..252b898c70 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request): + """Test bulk plan retrieval converts string expiration_date to int.""" + # Arrange + tenant_ids = ["tenant-1"] + mock_send_request.return_value = { + "data": { + "tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"}, + } + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert "tenant-1" in result + assert isinstance(result["tenant-1"]["expiration_date"], int) + assert result["tenant-1"]["expiration_date"] == 1735689600 + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" # Arrange diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 35157790ca..1bf4c0e172 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,13 +6,14 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. """ -from datetime import datetime, timedelta +from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.enums import ConversationFromSource @@ -122,8 +123,8 @@ class ConversationServiceTestDataFactory: conversation.is_deleted = kwargs.get("is_deleted", False) conversation.name = kwargs.get("name", "Test Conversation") conversation.status = kwargs.get("status", "normal") - conversation.created_at = kwargs.get("created_at", datetime.utcnow()) - conversation.updated_at = kwargs.get("updated_at", datetime.utcnow()) + conversation.created_at = kwargs.get("created_at", naive_utc_now()) + conversation.updated_at = kwargs.get("updated_at", naive_utc_now()) for key, value in kwargs.items(): setattr(conversation, key, value) return conversation @@ -152,7 +153,7 @@ class ConversationServiceTestDataFactory: message.conversation_id = conversation_id message.app_id = app_id message.query = kwargs.get("query", "Test message content") - message.created_at = kwargs.get("created_at", datetime.utcnow()) + message.created_at = kwargs.get("created_at", naive_utc_now()) for key, value in kwargs.items(): setattr(message, key, value) return message @@ -181,8 +182,8 @@ class ConversationServiceTestDataFactory: variable.conversation_id = conversation_id variable.app_id = app_id variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} - variable.created_at = kwargs.get("created_at", datetime.utcnow()) - variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + variable.created_at = kwargs.get("created_at", naive_utc_now()) + variable.updated_at = kwargs.get("updated_at", naive_utc_now()) # Mock to_variable method mock_variable = Mock() @@ -302,7 +303,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.updated_at = datetime.utcnow() + mock_conversation.updated_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -323,7 +324,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.created_at = datetime.utcnow() + mock_conversation.created_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -668,9 +669,9 @@ class TestConversationServiceConversationalVariable: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( - created_at=datetime.utcnow() - timedelta(hours=1) + created_at=naive_utc_now() - timedelta(hours=1) ) - variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) mock_session.scalar.return_value = last_variable mock_session.scalars.return_value.all.return_value = [variable] diff --git a/api/tests/unit_tests/services/test_conversation_variable_updater.py b/api/tests/unit_tests/services/test_conversation_variable_updater.py deleted file mode 100644 index 20f7caa78e..0000000000 --- a/api/tests/unit_tests/services/test_conversation_variable_updater.py +++ /dev/null @@ -1,75 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from dify_graph.variables import StringVariable -from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater - - -class TestConversationVariableUpdater: - def test_should_update_conversation_variable_data_and_commit(self): - """Test update persists serialized variable data when the row exists.""" - conversation_id = "conv-123" - variable = StringVariable( - id="var-123", - name="topic", - value="new value", - ) - expected_json = variable.model_dump_json() - - row = SimpleNamespace(data="old value") - session = MagicMock() - session.scalar.return_value = row - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - updater.update(conversation_id=conversation_id, variable=variable) - - session_maker.assert_called_once_with() - session.scalar.assert_called_once() - stmt = session.scalar.call_args.args[0] - compiled_params = stmt.compile().params - assert variable.id in compiled_params.values() - assert conversation_id in compiled_params.values() - assert row.data == expected_json - session.commit.assert_called_once() - - def test_should_raise_not_found_error_when_conversation_variable_missing(self): - """Test update raises ConversationVariableNotFoundError when no matching row exists.""" - conversation_id = "conv-404" - variable = StringVariable( - id="var-404", - name="topic", - value="value", - ) - - session = MagicMock() - session.scalar.return_value = None - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - session_maker = MagicMock(return_value=session_context) - updater = ConversationVariableUpdater(session_maker) - - with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): - updater.update(conversation_id=conversation_id, variable=variable) - - session.commit.assert_not_called() - - def test_should_do_nothing_when_flush_is_called(self): - """Test flush currently behaves as a no-op and returns None.""" - session_maker = MagicMock() - updater = ConversationVariableUpdater(session_maker) - - result = updater.flush() - - assert result is None - session_maker.assert_not_called() diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py deleted file mode 100644 index 9ef314cb9e..0000000000 --- a/api/tests/unit_tests/services/test_credit_pool_service.py +++ /dev/null @@ -1,157 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -import services.credit_pool_service as credit_pool_service_module -from core.errors.error import QuotaExceededError -from models import TenantCreditPool -from services.credit_pool_service import CreditPoolService - - -@pytest.fixture -def mock_credit_deduction_setup(): - """Fixture providing common setup for credit deduction tests.""" - pool = SimpleNamespace(remaining_credits=50) - fake_engine = MagicMock() - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool) - mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine)) - mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context) - - return { - "pool": pool, - "fake_engine": fake_engine, - "session": session, - "session_context": session_context, - "patches": (mock_get_pool, mock_db, mock_session), - } - - -class TestCreditPoolService: - def test_should_create_default_pool_with_trial_type_and_configured_quota(self): - """Test create_default_pool persists a trial pool using configured hosted credits.""" - tenant_id = "tenant-123" - hosted_pool_credits = 5000 - - with ( - patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits), - patch.object(credit_pool_service_module, "db") as mock_db, - ): - pool = CreditPoolService.create_default_pool(tenant_id) - - assert isinstance(pool, TenantCreditPool) - assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" - assert pool.quota_limit == hosted_pool_credits - assert pool.quota_used == 0 - mock_db.session.add.assert_called_once_with(pool) - mock_db.session.commit.assert_called_once() - - def test_should_return_first_pool_from_query_when_get_pool_called(self): - """Test get_pool queries by tenant and pool_type and returns first result.""" - tenant_id = "tenant-123" - pool_type = "enterprise" - expected_pool = MagicMock(spec=TenantCreditPool) - - with patch.object(credit_pool_service_module, "db") as mock_db: - query = mock_db.session.query.return_value - filtered_query = query.filter_by.return_value - filtered_query.first.return_value = expected_pool - - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type) - - assert result == expected_pool - mock_db.session.query.assert_called_once_with(TenantCreditPool) - query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type) - filtered_query.first.assert_called_once() - - def test_should_return_false_when_pool_not_found_in_check_credits_available(self): - """Test check_credits_available returns False when tenant has no pool.""" - with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10) - - assert result is False - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_true_when_remaining_credits_cover_required_amount(self): - """Test check_credits_available returns True when remaining credits are sufficient.""" - pool = SimpleNamespace(remaining_credits=100) - - with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool: - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is True - mock_get_pool.assert_called_once_with("tenant-123", "trial") - - def test_should_return_false_when_remaining_credits_are_insufficient(self): - """Test check_credits_available returns False when required credits exceed remaining credits.""" - pool = SimpleNamespace(remaining_credits=30) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) - - assert result is False - - def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self): - """Test check_and_deduct_credits raises when tenant credit pool does not exist.""" - with patch.object(CreditPoolService, "get_pool", return_value=None): - with pytest.raises(QuotaExceededError, match="Credit pool not found"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self): - """Test check_and_deduct_credits raises when remaining credits are zero or negative.""" - pool = SimpleNamespace(remaining_credits=0) - - with patch.object(CreditPoolService, "get_pool", return_value=pool): - with pytest.raises(QuotaExceededError, match="No credits remaining"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits updates quota_used by the actual deducted amount.""" - tenant_id = "tenant-123" - pool_type = "trial" - credits_required = 200 - remaining_credits = 120 - expected_deducted_credits = 120 - - mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits - patches = mock_credit_deduction_setup["patches"] - session = mock_credit_deduction_setup["session"] - - with patches[0], patches[1], patches[2]: - result = CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=credits_required, - pool_type=pool_type, - ) - - assert result == expected_deducted_credits - session.execute.assert_called_once() - session.commit.assert_called_once() - - stmt = session.execute.call_args.args[0] - compiled_params = stmt.compile().params - assert tenant_id in compiled_params.values() - assert pool_type in compiled_params.values() - assert expected_deducted_credits in compiled_params.values() - - def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup): - """Test check_and_deduct_credits translates DB update failures to QuotaExceededError.""" - mock_credit_deduction_setup["pool"].remaining_credits = 50 - mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure") - session = mock_credit_deduction_setup["session"] - - patches = mock_credit_deduction_setup["patches"] - mock_logger = patch.object(credit_pool_service_module, "logger") - - with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj: - with pytest.raises(QuotaExceededError, match="Failed to deduct credits"): - CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) - - session.commit.assert_not_called() - mock_logger_obj.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py deleted file mode 100644 index 4974d6c1ef..0000000000 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ /dev/null @@ -1,305 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - - -class DatasetPermissionTestDataFactory: - """Factory class for creating test data and mock objects for dataset permission tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.permission = permission - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_dataset_permission_mock( - dataset_id: str = "dataset-123", - account_id: str = "user-789", - **kwargs, - ) -> Mock: - """Create a mock dataset permission record.""" - permission = Mock(spec=DatasetPermission) - permission.dataset_id = dataset_id - permission.account_id = account_id - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - -class TestDatasetPermissionService: - """ - Comprehensive unit tests for DatasetService.check_dataset_permission method. - - This test suite covers all permission scenarios including: - - Cross-tenant access restrictions - - Owner privilege checks - - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) - - Explicit permission checks for PARTIAL_TEAM - - Error conditions and logging - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with patch("services.dataset_service.db.session") as mock_session: - yield { - "db_session": mock_session, - } - - @pytest.fixture - def mock_logging_dependencies(self): - """Mock setup for logging tests.""" - with patch("services.dataset_service.logger") as mock_logging: - yield { - "logging": mock_logging, - } - - def _assert_permission_check_passes(self, dataset: Mock, user: Mock): - """Helper method to verify that permission check passes without raising exceptions.""" - # Should not raise any exception - DatasetService.check_dataset_permission(dataset, user) - - def _assert_permission_check_fails( - self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset." - ): - """Helper method to verify that permission check fails with expected error.""" - with pytest.raises(NoPermissionError, match=expected_message): - DatasetService.check_dataset_permission(dataset, user) - - def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str): - """Helper method to verify database query calls for permission checks.""" - mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id) - - def _assert_database_query_not_called(self, mock_session: Mock): - """Helper method to verify that database query was not called.""" - mock_session.query.assert_not_called() - - # ==================== Cross-Tenant Access Tests ==================== - - def test_permission_check_different_tenant_should_fail(self): - """Test that users from different tenants cannot access dataset regardless of other permissions.""" - # Create dataset and user from different tenants - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM - ) - user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR - ) - - # Should fail due to different tenant - self._assert_permission_check_fails(dataset, user) - - # ==================== Owner Privilege Tests ==================== - - def test_owner_can_access_any_dataset(self): - """Test that tenant owners can access any dataset regardless of permission level.""" - # Create dataset with restrictive permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - - # Create owner user - owner_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="owner-999", role=TenantAccountRole.OWNER - ) - - # Owner should have access regardless of dataset permission - self._assert_permission_check_passes(dataset, owner_user) - - # ==================== ONLY_ME Permission Tests ==================== - - def test_only_me_permission_creator_can_access(self): - """Test ONLY_ME permission allows only the dataset creator to access.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should be able to access - self._assert_permission_check_passes(dataset, creator_user) - - def test_only_me_permission_others_cannot_access(self): - """Test ONLY_ME permission denies access to non-creators.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Non-creator should be denied access - self._assert_permission_check_fails(dataset, normal_user) - - # ==================== ALL_TEAM Permission Tests ==================== - - def test_all_team_permission_allows_access(self): - """Test ALL_TEAM permission allows any team member to access the dataset.""" - # Create dataset with ALL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM) - - # Create different types of team members - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - editor_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="editor-456", role=TenantAccountRole.EDITOR - ) - - # All team members should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_permission_check_passes(dataset, editor_user) - - # ==================== PARTIAL_TEAM Permission Tests ==================== - - def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows creator to access without database query.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should have access without database query - self._assert_permission_check_passes(dataset, creator_user) - self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"]) - - def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows users with explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return a permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=normal_user.id - ) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission - - # User with explicit permission should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission denies users without explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # User without explicit permission should be denied access - self._assert_permission_check_fails(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies): - """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create a different user (not the creator) - other_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="other-user-123", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Non-creator without explicit permission should be denied access - self._assert_permission_check_fails(dataset, other_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id) - - # ==================== Enum Usage Tests ==================== - - def test_partial_team_permission_uses_correct_enum(self): - """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals.""" - # Create dataset with PARTIAL_TEAM permission using enum - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should always have access regardless of permission level - self._assert_permission_check_passes(dataset, creator_user) - - # ==================== Logging Tests ==================== - - def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies): - """Test that permission denied events are properly logged for debugging purposes.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Attempt permission check (should fail) - with pytest.raises(NoPermissionError): - DatasetService.check_dataset_permission(dataset, normal_user) - - # Verify debug message was logged with correct user and dataset information - mock_logging_dependencies["logging"].debug.assert_called_with( - "User %s does not have permission to access dataset %s", normal_user.id, dataset.id - ) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py deleted file mode 100644 index a1d2f6410c..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Unit tests for non-SQL DocumentService orchestration behaviors. - -This file intentionally keeps only collaborator-oriented document indexing -orchestration tests. SQL-backed dataset lifecycle cases are covered by -integration tests under testcontainers. -""" - -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Document -from services.errors.document import DocumentIndexingError - - -class DatasetServiceUnitDataFactory: - """Factory for creating lightweight document doubles used in unit tests.""" - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - indexing_status: str = "completed", - is_paused: bool = False, - ) -> Mock: - """Create a document-shaped mock for DocumentService orchestration tests.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.indexing_status = indexing_status - document.is_paused = is_paused - document.paused_by = None - document.paused_at = None - return document - - -class TestDatasetServiceDocumentIndexing: - """Unit tests for pause/recover/retry orchestration without SQL assertions.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Patch non-SQL collaborators used by DocumentService methods.""" - with ( - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - mock_current_user.id = "user-123" - yield { - "redis_client": mock_redis, - "db_session": mock_db, - "current_user": mock_current_user, - } - - def test_pause_document_success(self, mock_document_service_dependencies): - """Pause a document that is currently in an indexable status.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") - - # Act - from services.dataset_service import DocumentService - - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( - f"document_{document.id}_is_paused", - "True", - ) - - def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - """Raise DocumentIndexingError when pausing a completed document.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - - # Act / Assert - from services.dataset_service import DocumentService - - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - def test_recover_document_success(self, mock_document_service_dependencies): - """Recover a paused document and dispatch the recover indexing task.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - - # Act - with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - from services.dataset_service import DocumentService - - DocumentService.recover_document(document) - - # Assert - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( - f"document_{document.id}_is_paused" - ) - recover_task.delay.assert_called_once_with(document.dataset_id, document.id) - - def test_retry_document_indexing_success(self, mock_document_service_dependencies): - """Reset documents to waiting state and dispatch retry indexing task.""" - # Arrange - dataset_id = "dataset-123" - documents = [ - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), - ] - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - with patch("services.dataset_service.retry_document_indexing_task") as retry_task: - from services.dataset_service import DocumentService - - DocumentService.retry_document(dataset_id, documents) - - # Assert - assert all(document.indexing_status == "waiting" for document in documents) - assert mock_document_service_dependencies["db_session"].add.call_count == 2 - assert mock_document_service_dependencies["db_session"].commit.call_count == 2 - assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 - retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py deleted file mode 100644 index abff48347e..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ /dev/null @@ -1,100 +0,0 @@ -import datetime -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.dataset_service import DocumentService -from tests.unit_tests.conftest import redis_mock - - -class DocumentBatchUpdateTestDataFactory: - """Factory class for creating test data and mock objects for document batch update tests.""" - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - return dataset - - @staticmethod - def create_user_mock(user_id: str = "user-789") -> Mock: - """Create a mock user.""" - user = Mock() - user.id = user_id - return user - - @staticmethod - def create_document_mock( - document_id: str = "doc-1", - name: str = "test_document.pdf", - enabled: bool = True, - archived: bool = False, - indexing_status: str = "completed", - completed_at: datetime.datetime | None = None, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.name = name - document.enabled = enabled - document.archived = archived - document.indexing_status = indexing_status - document.completed_at = completed_at or datetime.datetime.now() - - document.disabled_at = None - document.disabled_by = None - document.archived_at = None - document.archived_by = None - document.updated_at = None - - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - -class TestDatasetServiceBatchUpdateDocumentStatus: - """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Common mock setup for document service dependencies.""" - with ( - patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_document": mock_get_doc, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): - """Test that ValueError is raised when an invalid action is provided.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = doc - - redis_mock.reset_mock() - redis_mock.get.return_value = None - - invalid_action = "invalid_action" - with pytest.raises(ValueError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user - ) - - assert invalid_action in str(exc_info.value) - assert "Invalid action" in str(exc_info.value) - - redis_mock.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py deleted file mode 100644 index f8c5270656..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" - -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest - -from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity - - -class TestDatasetServiceCreateRagPipelineDatasetNonSQL: - """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" - - @pytest.fixture - def mock_rag_pipeline_dependencies(self): - """Patch database session and current_user for validation-only unit coverage.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - yield { - "db_session": mock_db, - "current_user_mock": mock_current_user, - } - - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Raise ValueError when current_user.id is unavailable before SQL persistence.""" - # Arrange - tenant_id = str(uuid4()) - mock_rag_pipeline_dependencies["current_user_mock"].id = None - - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="Test Dataset", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act / Assert - with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, - rag_pipeline_dataset_create_entity=entity, - ) diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py new file mode 100644 index 0000000000..92aed7c30a --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -0,0 +1,1760 @@ +"""Unit tests for DatasetService and dataset-related collaborators.""" + +from .dataset_service_test_helpers import ( + CloudPlan, + Dataset, + DatasetCollectionBindingService, + DatasetNameDuplicateError, + DatasetPermissionEnum, + DatasetPermissionService, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DocumentIndexingError, + DocumentService, + LLMBadRequestError, + MagicMock, + Mock, + ModelFeature, + ModelType, + NoPermissionError, + NotFound, + PipelineIconInfo, + ProviderTokenNotInitError, + RagPipelineDatasetCreateEntity, + SimpleNamespace, + TenantAccountRole, + _make_knowledge_configuration, + _make_retrieval_model, + _make_session_context, + json, + patch, + pytest, +) + + +class TestDatasetServiceQueries: + """Unit tests for DatasetService query composition and fallback branches.""" + + @pytest.fixture + def mock_dataset_query_dependencies(self): + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped-search") as escape_like, + patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as get_target_ids, + ): + mock_db.paginate.return_value = SimpleNamespace(items=["dataset"], total=1) + yield { + "db": mock_db, + "escape_like_pattern": escape_like, + "get_target_ids": get_target_ids, + } + + def test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies): + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1") + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + mock_dataset_query_dependencies["escape_like_pattern"].assert_not_called() + + def test_get_datasets_short_circuits_for_dataset_operator_without_permissions( + self, mock_dataset_query_dependencies + ): + user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR) + mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = [] + + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = [] + + items, total = DatasetService.get_datasets( + page=1, + per_page=20, + tenant_id="tenant-1", + tag_ids=["tag-1"], + ) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = ["dataset-1"] + + items, total = DatasetService.get_datasets( + page=2, + per_page=10, + tenant_id="tenant-1", + search="report", + tag_ids=["tag-1"], + ) + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["escape_like_pattern"].assert_called_once_with("report") + mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + + def test_get_process_rules_returns_latest_rule_when_present(self): + dataset_process_rule = Mock(spec=DatasetProcessRule) + dataset_process_rule.mode = "automatic" + dataset_process_rule.rules_dict = {"delimiter": "\n"} + + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = dataset_process_rule + + result = DatasetService.get_process_rules("dataset-1") + + assert result == {"mode": "automatic", "rules": {"delimiter": "\n"}} + + def test_get_process_rules_falls_back_to_default_rules_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = None + + result = DatasetService.get_process_rules("dataset-1") + + assert result == { + "mode": DocumentService.DEFAULT_RULES["mode"], + "rules": DocumentService.DEFAULT_RULES["rules"], + } + + def test_get_datasets_by_ids_returns_empty_for_missing_ids(self): + with patch("services.dataset_service.db") as mock_db: + items, total = DatasetService.get_datasets_by_ids([], "tenant-1") + + assert items == [] + assert total == 0 + mock_db.paginate.assert_not_called() + + def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.paginate.return_value = SimpleNamespace(items=["dataset-1"], total=1) + + items, total = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1") + + assert items == ["dataset-1"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_dataset_returns_first_match(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset + + result = DatasetService.get_dataset(dataset.id) + + assert result is dataset + + +class TestDatasetServiceValidation: + """Unit tests for DatasetService validation helpers.""" + + @pytest.mark.parametrize( + ("dataset_doc_form", "incoming_doc_form"), + [(None, "text_model"), ("text_model", "text_model")], + ) + def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form) + + DatasetService.check_doc_form(dataset, incoming_doc_form) + + def test_check_doc_form_rejects_mismatched_doc_form(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model") + + with pytest.raises(ValueError, match="doc_form is different"): + DatasetService.check_doc_form(dataset, "text_model") + + def test_check_dataset_model_setting_skips_non_high_quality_datasets(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.assert_not_called() + + def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + def test_check_dataset_model_setting_wraps_llm_bad_request_error(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_dataset_model_setting_wraps_provider_token_error(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_embedding_model_setting_wraps_provider_token_error_description(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "provider setup" + ) + + with pytest.raises(ValueError, match="provider setup"): + DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model") + + def test_check_reranking_model_setting_uses_rerank_model_type(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider", + model_type=ModelType.RERANK, + model="reranker", + ) + + def test_check_reranking_model_setting_wraps_bad_request(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Rerank Model available"): + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self): + model_schema = SimpleNamespace(features=[ModelFeature.VISION]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is True + + def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self): + model_schema = SimpleNamespace(features=[]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is False + + def test_check_is_multimodal_model_raises_when_schema_is_missing(self): + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + with pytest.raises(ValueError, match="Model schema not found"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + def test_check_is_multimodal_model_wraps_bad_request_error(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Model available"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + +class TestDatasetServiceCreationAndUpdate: + """Unit tests for dataset creation and update helpers.""" + + def test_create_empty_dataset_raises_when_name_already_exists(self): + account = SimpleNamespace(id="user-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"): + DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account) + + def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + default_embedding_model = SimpleNamespace(provider="provider", model_name="default-embedding") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model + + dataset = DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "default-embedding" + assert dataset.permission == DatasetPermissionEnum.ONLY_ME + assert dataset.provider == "vendor" + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + check_embedding.assert_not_called() + mock_db.session.commit.assert_called_once() + + def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + retrieval_model = _make_retrieval_model() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch( + "services.dataset_service.ExternalKnowledgeBindings", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) as binding_cls, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()), + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + dataset = DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="External Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + provider="external", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + embedding_model_provider="provider", + embedding_model_name="embedding-model", + retrieval_model=retrieval_model, + summary_index_setting={"enable": True}, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_embedding.assert_called_once_with("tenant-1", "provider", "embedding-model") + check_reranking.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model") + binding_cls.assert_called_once_with( + tenant_id="tenant-1", + dataset_id="dataset-1", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + created_by="user-1", + ) + assert mock_db.session.add.call_count == 2 + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self): + entity = RagPipelineDatasetCreateEntity( + name="Existing Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self): + entity = RagPipelineDatasetCreateEntity( + name="", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + pipeline = SimpleNamespace(id="pipeline-1") + + def pipeline_factory(**kwargs): + pipeline.__dict__.update(kwargs) + return pipeline + + def dataset_factory(**kwargs): + return SimpleNamespace(id="dataset-1", **kwargs) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name, + patch("services.dataset_service.Pipeline", side_effect=pipeline_factory), + patch("services.dataset_service.Dataset", side_effect=dataset_factory), + ): + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [ + SimpleNamespace(name="Untitled"), + SimpleNamespace(name="Untitled 1"), + ] + + dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + assert entity.name == "Untitled 2" + assert dataset.pipeline_id == "pipeline-1" + assert dataset.runtime_mode == "rag_pipeline" + generate_name.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled") + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self): + entity = RagPipelineDatasetCreateEntity( + name="Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_update_dataset_raises_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1")) + + def test_update_dataset_raises_when_new_name_conflicts(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.name = "Old Dataset" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "_has_dataset_same_name", return_value=True), + ): + with pytest.raises(ValueError, match="Dataset name already exists"): + DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1")) + + def test_update_dataset_routes_external_datasets_to_external_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "external" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + update_external.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_update_dataset_routes_internal_datasets_to_internal_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "vendor" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_has_dataset_same_name_returns_true_when_query_matches(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = object() + + result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset") + + assert result is True + + def test_update_external_dataset_updates_dataset_and_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + + with ( + patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService._update_external_dataset( + dataset, + { + "external_retrieval_model": {"top_k": 3}, + "summary_index_setting": {"enable": True}, + "name": "Updated Dataset", + "description": "Updated description", + "permission": DatasetPermissionEnum.PARTIAL_TEAM, + "external_knowledge_id": "knowledge-1", + "external_knowledge_api_id": "api-1", + }, + user, + ) + + assert result is dataset + assert dataset.retrieval_model == {"top_k": 3} + assert dataset.summary_index_setting == {"enable": True} + assert dataset.name == "Updated Dataset" + assert dataset.description == "Updated description" + assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1") + mock_db.session.add.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + @pytest.mark.parametrize( + ("payload", "message"), + [ + ({"external_knowledge_api_id": "api-1"}, "External knowledge id is required"), + ({"external_knowledge_id": "knowledge-1"}, "External knowledge api id is required"), + ], + ) + def test_update_external_dataset_requires_external_binding_fields(self, payload, message): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + + with pytest.raises(ValueError, match=message): + DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1")) + + def test_update_external_knowledge_binding_updates_changed_binding_values(self): + binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = binding + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.Session", return_value=session_context), + ): + DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api") + + assert binding.external_knowledge_id == "new-knowledge" + assert binding.external_knowledge_api_id == "new-api" + mock_db.session.add.assert_called_once_with(binding) + + def test_update_external_knowledge_binding_raises_for_missing_binding(self): + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = None + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db"), + patch("services.dataset_service.Session", return_value=session_context), + ): + with pytest.raises(ValueError, match="External knowledge binding not found"): + DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1") + + def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + update_payload = { + "name": "Updated Dataset", + "description": None, + "partial_member_list": [{"user_id": "member-1"}], + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "knowledge-1", + "external_retrieval_model": {"top_k": 2}, + "retrieval_model": {"top_k": 4}, + "summary_index_setting": {"enable": True}, + "icon_info": {"icon": "book"}, + } + + with ( + patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"), + patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as update_pipeline, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task, + patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task, + ): + result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user) + + assert result is dataset + updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0] + assert updated_values["name"] == "Updated Dataset" + assert updated_values["description"] is None + assert updated_values["retrieval_model"] == {"top_k": 4} + assert updated_values["summary_index_setting"] == {"enable": True} + assert updated_values["icon_info"] == {"icon": "book"} + assert updated_values["updated_by"] == "user-1" + assert updated_values["updated_at"] is now + assert "partial_member_list" not in updated_values + assert "external_knowledge_api_id" not in updated_values + assert "external_knowledge_id" not in updated_values + assert "external_retrieval_model" not in updated_values + mock_db.session.commit.assert_called_once() + mock_db.session.refresh.assert_called_once_with(dataset) + update_pipeline.assert_called_once_with(dataset, "user-1") + vector_task.delay.assert_called_once_with("dataset-1", "update") + regenerate_task.delay.assert_called_once_with( + "dataset-1", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self): + dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.query.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.commit.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self): + dataset = SimpleNamespace( + id="dataset-1", + runtime_mode="rag_pipeline", + pipeline_id="pipeline-1", + embedding_model="embedding-model", + embedding_model_provider="provider", + retrieval_model={"top_k": 5}, + chunk_structure="paragraph", + indexing_technique="high_quality", + keyword_number=8, + summary_index_setting={"enable": True}, + ) + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + published_workflow = SimpleNamespace( + graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}, {"data": {"type": "start"}}]}), + type="chat", + features={"feature": True}, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}]})) + new_workflow = SimpleNamespace(id="workflow-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.return_value = published_workflow + rag_pipeline_service.get_draft_workflow.return_value = draft_workflow + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + published_graph = json.loads(workflow_new.call_args.kwargs["graph"]) + assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model" + assert published_graph["nodes"][0]["data"]["summary_index_setting"] == {"enable": True} + assert json.loads(draft_workflow.graph)["nodes"][0]["data"]["embedding_model_provider"] == "provider" + mock_db.session.add.assert_any_call(new_workflow) + mock_db.session.add.assert_any_call(draft_workflow) + mock_db.session.commit.assert_called_once() + + def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.side_effect = RuntimeError("boom") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + with pytest.raises(RuntimeError, match="boom"): + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.rollback.assert_called_once() + + def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data) + + assert result is None + assert filtered_data == {} + + def test_handle_indexing_technique_change_switches_to_economy(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "economy"}, + filtered_data, + ) + + assert result == "remove" + assert filtered_data == { + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + } + + def test_handle_indexing_technique_change_switches_to_high_quality(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "add" + configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data) + + def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + with patch.object( + DatasetService, + "_handle_embedding_model_update_when_technique_unchanged", + return_value="update", + ) as update_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "update" + update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data) + + def test_configure_embedding_model_for_high_quality_updates_filtered_data(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model", + "embedding_model_provider": "provider", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("error", "message"), + [ + (LLMBadRequestError(), "No Embedding Model available"), + (ProviderTokenNotInitError("provider setup"), "provider setup"), + ], + ) + def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error + + with pytest.raises(ValueError, match=message): + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + filtered_data: dict[str, object] = {} + + with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {}, + filtered_data, + ) + + assert result is None + preserve_settings.assert_called_once_with(dataset, filtered_data) + + def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + update_settings.assert_called_once() + + def test_preserve_existing_embedding_settings_keeps_current_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider=None, + embedding_model=None, + collection_binding_id=None, + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == {} + + def test_update_embedding_model_settings_returns_update_for_changed_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings: + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + apply_settings.assert_called_once() + + def test_update_embedding_model_settings_returns_none_for_unchanged_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + assert result is None + + def test_update_embedding_model_settings_wraps_bad_request_errors(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()): + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + def test_apply_new_embedding_settings_updates_binding_for_new_model(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model-two", + "embedding_model_provider": "provider-two", + "collection_binding_id": "binding-2", + } + + def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("summary_index_setting", "expected"), + [ + (None, False), + ({"enable": False}, False), + ({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False), + ({"enable": True, "model_name": "new-model", "model_provider_name": "provider-two"}, True), + ], + ) + def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, + ) + + result = DatasetService._check_summary_index_setting_model_changed( + dataset, + {"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {}, + ) + + assert result is expected + + +class TestDatasetServiceRagPipelineSettings: + """Unit tests for rag-pipeline dataset setting updates.""" + + def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + knowledge_configuration = _make_knowledge_configuration() + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)): + with pytest.raises(ValueError, match="Current user or current tenant not found"): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True}) + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.chunk_structure == "paragraph" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is True + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model") + session.add.assert_called_once_with(dataset) + session.commit.assert_not_called() + + def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=12, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.indexing_technique == "economy" + assert dataset.keyword_number == 12 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence") + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises( + ValueError, + match="Knowledge base indexing technique is not allowed to be updated to economy", + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=False), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is False + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "add") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + summary_index_setting={"enable": True}, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider-two" + assert dataset.embedding_model == "embedding-model-two" + assert dataset.collection_binding_id == "binding-2" + assert dataset.is_multimodal is True + assert dataset.summary_index_setting == {"enable": True} + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + dataset.keyword_number = 5 + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=9, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.keyword_number == 9 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_not_called() + + +class TestDatasetServicePermissionsAndLifecycle: + """Unit tests for dataset permissions, deletion, and metadata helpers.""" + + def test_delete_dataset_returns_false_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1")) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.delete_dataset(dataset.id, user=SimpleNamespace(id="user-1")) + + assert result is True + check_permission.assert_called_once_with(dataset, SimpleNamespace(id="user-1")) + send_deleted_signal.assert_called_once_with(dataset) + mock_db.session.delete.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + def test_dataset_use_check_returns_scalar_result(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.execute.return_value.scalar_one.return_value = True + + result = DatasetService.dataset_use_check("dataset-1") + + assert result is True + + def test_check_dataset_permission_rejects_cross_tenant_access(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a") + user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b") + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="creator-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="creator-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + DatasetService.check_dataset_permission(dataset, user) + + mock_db.session.query.assert_not_called() + + def test_check_dataset_permission_allows_partial_team_member_with_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_operator_permission_validates_required_arguments(self): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None) + + with pytest.raises(ValueError, match="User not found"): + DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1")) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [] + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_get_dataset_queries_delegates_to_paginate(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1) + + items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20) + + assert items == ["query"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_related_apps_returns_ordered_query_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + "relation-1" + ] + + result = DatasetService.get_related_apps("dataset-1") + + assert result == ["relation-1"] + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status("dataset-1", True) + + def test_update_dataset_api_status_requires_current_user_id(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + now = object() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + assert dataset.enable_api is True + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + mock_db.session.commit.assert_called_once() + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + assert result == {"document_ids": [], "count": 0} + mock_db.session.scalars.assert_not_called() + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = logs + + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2} + + +class TestDatasetServiceDocumentIndexing: + """Unit tests for pause/recover/retry orchestration without SQL assertions.""" + + @pytest.fixture + def mock_document_service_dependencies(self): + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db.session") as mock_db_session, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.id = "user-123" + yield { + "redis_client": mock_redis, + "db_session": mock_db_session, + "current_user": mock_current_user, + } + + def test_pause_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") + + DocumentService.pause_document(document) + + assert document.is_paused is True + assert document.paused_by == "user-123" + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( + f"document_{document.id}_is_paused", + "True", + ) + + def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") + + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(document) + + def test_recover_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) + + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(document) + + assert document.is_paused is False + assert document.paused_by is None + assert document.paused_at is None + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( + f"document_{document.id}_is_paused" + ) + recover_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_retry_document_indexing_success(self, mock_document_service_dependencies): + dataset_id = "dataset-123" + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), + ] + mock_document_service_dependencies["redis_client"].get.return_value = None + + with patch("services.dataset_service.retry_document_indexing_task") as retry_task: + DocumentService.retry_document(dataset_id, documents) + + assert all(document.indexing_status == "waiting" for document in documents) + assert mock_document_service_dependencies["db_session"].add.call_count == 2 + assert mock_document_service_dependencies["db_session"].commit.call_count == 2 + assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 + retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") + + +class TestDatasetCollectionBindingService: + """Unit tests for dataset collection binding lookups and creation.""" + + def test_get_dataset_collection_binding_returns_existing_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result is binding + mock_db.session.add.assert_not_called() + + def test_get_dataset_collection_binding_creates_binding_when_missing(self): + created_binding = SimpleNamespace(id="binding-2") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, + patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") + + assert result is created_binding + binding_cls.assert_called_once_with( + provider_name="provider", + model_name="model", + collection_name="generated-collection", + type="dataset", + ) + mock_db.session.add.assert_called_once_with(created_binding) + mock_db.session.commit.assert_called_once() + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + assert result is binding + + +class TestDatasetPermissionService: + """Unit tests for dataset partial-member management helpers.""" + + def test_get_dataset_partial_member_list_returns_scalar_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = ["user-1", "user-2"] + + result = DatasetPermissionService.get_dataset_partial_member_list("dataset-1") + + assert result == ["user-1", "user-2"] + + def test_update_partial_member_list_replaces_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}, {"user_id": "user-2"}], + ) + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.add_all.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.add_all.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}], + ) + + mock_db.session.rollback.assert_called_once() + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, "all_team", []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team") + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, "only_me", []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members") + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, "partial_members", []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_clear_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py new file mode 100644 index 0000000000..c8036487ab --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -0,0 +1,2078 @@ +"""Unit tests for DocumentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + BuiltInField, + CloudPlan, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DataSource, + DocumentIndexingError, + DocumentService, + FileInfo, + FileNotExistsError, + Forbidden, + IndexStructureType, + InfoList, + KnowledgeConfig, + MagicMock, + NoPermissionError, + NotFound, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalMethod, + RetrievalModel, + Rule, + Segmentation, + SimpleNamespace, + WebsiteInfo, + _make_dataset, + _make_document, + _make_features, + _make_lock_context, + _make_session_context, + _make_upload_knowledge_config, + create_autospec, + json, + patch, + pytest, +) + + +class TestDocumentServiceDisplayStatus: + """Unit tests for DocumentService display-status helpers.""" + + @pytest.mark.parametrize( + ("raw_status", "expected"), + [ + ("enabled", "available"), + ("AVAILABLE", "available"), + ("paused", "paused"), + ("unknown", None), + (None, None), + ], + ) + def test_normalize_display_status(self, raw_status, expected): + assert DocumentService.normalize_display_status(raw_status) == expected + + def test_build_display_status_filters_returns_empty_tuple_for_unknown_status(self): + assert DocumentService.build_display_status_filters("missing") == () + + def test_apply_display_status_filter_returns_original_query_for_unknown_status(self): + query = MagicMock() + + result = DocumentService.apply_display_status_filter(query, "missing") + + assert result is query + query.where.assert_not_called() + + def test_apply_display_status_filter_applies_where_for_known_status(self): + query = MagicMock() + filtered_query = MagicMock() + query.where.return_value = filtered_query + + result = DocumentService.apply_display_status_filter(query, "enabled") + + assert result is filtered_query + query.where.assert_called_once() + + +class TestDocumentServiceQueryAndDownloadHelpers: + """Unit tests for DocumentService query helpers and download flows.""" + + def test_get_document_returns_none_when_document_id_is_missing(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_document("dataset-1", None) + + assert result is None + mock_db.session.query.assert_not_called() + + def test_get_document_queries_by_dataset_and_document_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = document + + result = DocumentService.get_document("dataset-1", "doc-1") + + assert result is document + + def test_get_documents_by_ids_returns_empty_for_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_documents_by_ids("dataset-1", []) + + assert result == [] + mock_db.session.scalars.assert_not_called() + + def test_get_documents_by_ids_uses_single_batch_query(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_documents_by_ids("dataset-1", ["doc-1"]) + + assert result == [document] + mock_db.session.scalars.assert_called_once() + + def test_update_documents_need_summary_returns_zero_for_empty_input(self): + with patch("services.dataset_service.session_factory") as session_factory_mock: + result = DocumentService.update_documents_need_summary("dataset-1", []) + + assert result == 0 + session_factory_mock.create_session.assert_not_called() + + def test_update_documents_need_summary_updates_matching_documents_and_commits(self): + session = MagicMock() + session.query.return_value.filter.return_value.update.return_value = 2 + + with patch("services.dataset_service.session_factory") as session_factory_mock: + session_factory_mock.create_session.return_value = _make_session_context(session) + + result = DocumentService.update_documents_need_summary( + "dataset-1", + ["doc-1", "doc-2"], + need_summary=False, + ) + + assert result == 2 + session.commit.assert_called_once() + + def test_get_document_download_url_uses_upload_file_lookup_and_signed_url_helper(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch.object(DocumentService, "_get_upload_file_for_upload_file_document", return_value=upload_file), + patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url, + ): + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id="file-1", as_attachment=True) + + def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_type="not-upload-file") + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={}) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_returns_string_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={"upload_file_id": 99}) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + def test_get_upload_file_for_upload_file_document_returns_upload_file(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + + with patch( + "services.dataset_service.FileService.get_upload_files_by_ids", return_value={"file-1": upload_file} + ): + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result is upload_file + + def test_enrich_documents_with_summary_index_status_skips_lookup_when_summary_is_disabled(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(summary_index_setting={"enable": False}) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=False), + ] + + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + assert documents[0].summary_index_status is None + assert documents[1].summary_index_status is None + + def test_enrich_documents_with_summary_index_status_applies_summary_status_map(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True}, + ) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-3", need_summary=False), + ] + + with patch( + "services.summary_index_service.SummaryIndexService.get_documents_summary_index_status", + return_value={"doc-1": "completed", "doc-2": None}, + ) as get_status_map: + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + get_status_map.assert_called_once_with( + document_ids=["doc-1", "doc-2"], + dataset_id="dataset-1", + tenant_id="tenant-1", + ) + assert documents[0].summary_index_status == "completed" + assert documents[1].summary_index_status is None + assert documents[2].summary_index_status is None + + def test_generate_document_batch_download_zip_filename_uses_zip_extension(self): + fake_uuid = SimpleNamespace(hex="archive-id") + + with patch("services.dataset_service.uuid.uuid4", return_value=fake_uuid): + result = DocumentService._generate_document_batch_download_zip_filename() + + assert result == "archive-id.zip" + + def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(self): + with patch.object(DocumentService, "get_documents_by_ids", return_value=[]): + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-other", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch.object(DocumentService, "get_documents_by_ids", return_value=[document]): + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document]), + patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}), + ): + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(self): + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-2") + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document_a, document_b]), + patch( + "services.dataset_service.FileService.get_upload_files_by_ids", + return_value={"file-1": upload_file_a, "file-2": upload_file_b}, + ), + ): + result = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1", "doc-2"], + tenant_id="tenant-1", + ) + + assert result == {"doc-1": upload_file_a, "doc-2": upload_file_b} + + def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(self): + user = DatasetServiceUnitDataFactory.create_user_mock() + + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission", side_effect=NoPermissionError("blocked")), + ): + with pytest.raises(Forbidden, match="blocked"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = DatasetServiceUnitDataFactory.create_user_mock() + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-a") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-b") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + DocumentService, + "_get_upload_files_by_document_id_for_zip_download", + return_value={"doc-1": upload_file_a, "doc-2": upload_file_b}, + ), + patch.object(DocumentService, "_generate_document_batch_download_zip_filename", return_value="archive.zip"), + ): + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-2", "doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + assert upload_files == [upload_file_b, upload_file_a] + assert download_name == "archive.zip" + + def test_get_document_by_dataset_id_returns_enabled_documents(self): + document = DatasetServiceUnitDataFactory.create_document_mock(enabled=True) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_document_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_working_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed", archived=False) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_working_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_error_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="error") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_error_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_batch_documents_filters_by_current_user_tenant(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_batch_documents("dataset-1", "batch-1") + + assert result == [document] + + def test_get_document_file_detail_returns_one_or_none(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is upload_file + + +class TestDocumentServiceMutations: + """Unit tests for DocumentService mutation and orchestration helpers.""" + + @pytest.fixture + def rename_account_context(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.id = "user-123" + current_user.current_tenant_id = "tenant-123" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + ): + yield current_user + + @pytest.mark.parametrize(("archived", "expected"), [(True, True), (False, False)]) + def test_check_archived_returns_boolean_status(self, archived, expected): + document = DatasetServiceUnitDataFactory.create_document_mock(archived=archived) + + assert DocumentService.check_archived(document) is expected + + def test_delete_document_emits_signal_and_commits(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + data_source_type="upload_file", + data_source_info='{"upload_file_id": "file-1"}', + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch("services.dataset_service.document_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + DocumentService.delete_document(document) + + send_deleted_signal.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id="file-1", + ) + mock_db.session.delete.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_documents_ignores_empty_input(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with patch("services.dataset_service.db") as mock_db: + DocumentService.delete_documents(dataset, []) + + mock_db.session.scalars.assert_not_called() + + def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="text_model") + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.batch_clean_document_task") as clean_task, + ): + mock_db.session.scalars.return_value.all.return_value = [document_a, document_b] + + DocumentService.delete_documents(dataset, ["doc-1", "doc-2"]) + + assert mock_db.session.delete.call_count == 2 + mock_db.session.commit.assert_called_once() + clean_task.delay.assert_called_once_with(["doc-1", "doc-2"], dataset.id, dataset.doc_form, ["file-1", "file-2"]) + + def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("dataset-1", "doc-1", "New Name") + + def test_rename_document_raises_when_document_is_missing(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=None), + ): + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "doc-1", "New Name") + + def test_rename_document_rejects_cross_tenant_access(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + document = DatasetServiceUnitDataFactory.create_document_mock(tenant_id="tenant-other") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + ): + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "New Name") + + def test_rename_document_updates_document_metadata_and_upload_file_name(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + built_in_field_enabled=True, + tenant_id="tenant-1", + ) + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + doc_metadata={"title": "Old"}, + data_source_info_dict={"upload_file_id": "file-1"}, + ) + rename_account_context.current_tenant_id = "tenant-1" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.db") as mock_db, + ): + result = DocumentService.rename_document(dataset.id, document.id, "New Name") + + assert result is document + assert document.name == "New Name" + assert document.doc_metadata[BuiltInField.document_name] == "New Name" + mock_db.session.add.assert_called_once_with(document) + mock_db.session.query.return_value.where.return_value.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_recover_document_raises_when_document_is_not_paused(self): + document = DatasetServiceUnitDataFactory.create_document_mock(is_paused=False) + + with pytest.raises(DocumentIndexingError): + DocumentService.recover_document(document) + + def test_retry_document_raises_when_retry_flag_is_already_set(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="being retried"): + DocumentService.retry_document("dataset-1", [document]) + + def test_sync_website_document_raises_when_sync_flag_exists(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="being synced"): + DocumentService.sync_website_document("dataset-1", document) + + def test_sync_website_document_updates_status_sets_cache_and_dispatches_task(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_info_dict={"mode": "crawl"}, + ) + document.data_source_info = "{}" + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.sync_website_document_indexing_task") as sync_task, + ): + mock_redis.get.return_value = None + + DocumentService.sync_website_document("dataset-1", document) + + assert document.indexing_status == "waiting" + assert '"mode": "scrape"' in document.data_source_info + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with("document_doc-1_is_sync", 600, 1) + sync_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_get_documents_position_returns_next_position_when_documents_exist(self): + document = DatasetServiceUnitDataFactory.create_document_mock(position=7) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = ( + document + ) + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 8 + + def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 1 + + +class TestDocumentServiceSaveDocumentWithoutDatasetId: + """Unit tests for dataset creation around save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_creates_high_quality_dataset_with_default_retrieval_model( + self, account_context + ): + knowledge_config = KnowledgeConfig( + indexing_technique="high_quality", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + embedding_model="embedding-model", + embedding_model_provider="provider", + summary_index_setting={"enable": True}, + is_multimodal=True, + ) + created_dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="", + description=None, + ) + first_document = SimpleNamespace(name="VeryLongDocumentNameForDataset.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ) as dataset_cls, + patch.object( + DocumentService, "save_document_with_dataset_id", return_value=([first_document], "batch-1") + ) as save_document, + patch("services.dataset_service.db") as mock_db, + ): + dataset, documents, batch = DocumentService.save_document_without_dataset_id( + tenant_id="tenant-1", + knowledge_config=knowledge_config, + account=account_context, + ) + + assert dataset is created_dataset + assert documents == [first_document] + assert batch == "batch-1" + assert created_dataset.collection_binding_id == "binding-1" + assert created_dataset.retrieval_model["search_method"] == RetrievalMethod.SEMANTIC_SEARCH + assert created_dataset.retrieval_model["top_k"] == 4 + assert created_dataset.summary_index_setting == {"enable": True} + assert created_dataset.is_multimodal is True + assert created_dataset.name == first_document.name[:18] + "..." + assert ( + created_dataset.description + == "useful for when you want to answer queries about the VeryLongDocumentNameForDataset.txt" + ) + dataset_cls.assert_called_once() + save_document.assert_called_once_with(created_dataset, knowledge_config, account_context) + assert mock_db.session.commit.call_count == 1 + + def test_save_document_without_dataset_id_uses_provided_retrieval_model(self, account_context): + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=9, + score_threshold_enabled=True, + score_threshold=0.6, + ) + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + retrieval_model=retrieval_model, + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + assert created_dataset.retrieval_model == retrieval_model.model_dump() + assert created_dataset.collection_binding_id is None + + def test_save_document_without_dataset_id_rejects_sandbox_batch_upload(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1", "file-2"]), + ) + ), + ) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_not_called() + + +class TestDocumentServiceUpdateDocumentWithDatasetId: + """Unit tests for the document-update orchestration path.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_update_document_with_dataset_id_raises_when_document_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=None), + patch.object(DatasetService, "check_dataset_model_setting") as check_model_setting, + ): + with pytest.raises(NotFound, match="Document not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + check_model_setting.assert_called_once_with(dataset) + + def test_update_document_with_dataset_id_rejects_non_available_documents(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(display_status="indexing") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="Document is not available"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_upload_file_process_rule_and_name_override(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document.dataset_process_rule_id = "old-rule" + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + name="Renamed document", + doc_form=IndexStructureType.QA_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.DatasetProcessRule", return_value=created_process_rule), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 3 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.data_source_type == "upload_file" + assert document.data_source_info == '{"upload_file_id": "file-1"}' + assert document.name == "Renamed document" + assert document.indexing_status == "waiting" + assert document.completed_at is None + assert document.processing_started_at is None + assert document.parsing_completed_at is None + assert document.cleaning_completed_at is None + assert document.splitting_completed_at is None + assert document.updated_at == "now" + assert document.created_from == "web" + assert document.doc_form == IndexStructureType.QA_INDEX + assert mock_db.session.commit.call_count == 3 + segment_query.filter_by.return_value.update.assert_called_once() + update_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(display_status="available", id="doc-1", dataset_id="dataset-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page")], + ) + ], + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = None + mock_db.session.query.return_value = binding_query + + with pytest.raises(ValueError, match="Data source binding not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_website_crawl_updates_segments_and_dispatches_task(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com"], + only_main_content=False, + ), + ) + ), + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 2 + mock_db.session.query.return_value = segment_query + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "website_crawl" + assert document.data_source_info == ( + '{"url": "https://example.com", "provider": "firecrawl", "job_id": "job-1", ' + '"only_main_content": false, "mode": "crawl"}' + ) + assert document.name == "" + assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + segment_query.filter_by.return_value.update.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceCreateValidation: + """Unit tests for document creation validation helpers.""" + + def test_document_create_args_validate_requires_data_source_or_process_rule(self): + knowledge_config = SimpleNamespace(data_source=None, process_rule=None) + + with pytest.raises(ValueError, match="Data source or Process rule is required"): + DocumentService.document_create_args_validate(knowledge_config) + + def test_document_create_args_validate_delegates_to_sub_validators(self): + knowledge_config = SimpleNamespace(data_source=object(), process_rule=object()) + + with ( + patch.object(DocumentService, "data_source_args_validate") as validate_data_source, + patch.object(DocumentService, "process_rule_args_validate") as validate_process_rule, + ): + DocumentService.document_create_args_validate(knowledge_config) + + validate_data_source.assert_called_once_with(knowledge_config) + validate_process_rule.assert_called_once_with(knowledge_config) + + def test_data_source_args_validate_rejects_invalid_type(self): + knowledge_config = SimpleNamespace( + data_source=SimpleNamespace( + info_list=SimpleNamespace( + data_source_type="bad-source", + file_info_list=None, + notion_info_list=None, + website_info_list=None, + ) + ) + ) + + with pytest.raises(ValueError, match="Data source type is invalid"): + DocumentService.data_source_args_validate(knowledge_config) + + @pytest.mark.parametrize( + ("data_source_type", "field_name", "message"), + [ + ("upload_file", "file_info_list", "File source info is required"), + ("notion_import", "notion_info_list", "Notion source info is required"), + ("website_crawl", "website_info_list", "Website source info is required"), + ], + ) + def test_data_source_args_validate_requires_source_specific_info(self, data_source_type, field_name, message): + info_list = SimpleNamespace( + data_source_type=data_source_type, + file_info_list=object(), + notion_info_list=object(), + website_info_list=object(), + ) + setattr(info_list, field_name, None) + knowledge_config = SimpleNamespace(data_source=SimpleNamespace(info_list=info_list)) + + with pytest.raises(ValueError, match=message): + DocumentService.data_source_args_validate(knowledge_config) + + def test_process_rule_args_validate_clears_rules_for_automatic_mode(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is None + + def test_process_rule_args_validate_deduplicates_rules_and_skips_max_tokens_for_full_doc_hierarchical(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="hierarchical", + rules=Rule( + pre_processing_rules=[ + PreProcessingRule(id="remove_stopwords", enabled=True), + PreProcessingRule(id="remove_stopwords", enabled=False), + ], + segmentation=Segmentation(separator="\n", max_tokens=0), + parent_mode="full-doc", + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is not None + assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 + assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False + + +class TestDocumentServiceSaveDocumentWithDatasetId: + """Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_requires_file_info_for_upload_source(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)): + with pytest.raises(ValueError, match="File source info is required"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_blocks_batch_upload_for_sandbox_plan(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_enforces_batch_upload_limit(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 1), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_updates_existing_document_and_data_source_type(self, account_context): + dataset = _make_dataset(data_source_type=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + updated_document = _make_document(document_id="doc-1", batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch.object( + DocumentService, "update_document_with_dataset_id", return_value=updated_document + ) as update_document, + ): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert dataset.data_source_type == "upload_file" + assert documents == [updated_document] + assert batch == "batch-existing" + update_document.assert_called_once_with(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_data_source_for_new_documents(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(data_source=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Data source is required when creating new documents"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_existing_process_rule_for_custom_mode(self, account_context): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="custom"), + ) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_rejects_invalid_indexing_technique(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = SimpleNamespace( + doc_form=IndexStructureType.PARAGRAPH_INDEX, + original_document_id=None, + data_source=None, + indexing_technique="broken-technique", + ) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_returns_empty_for_invalid_process_rule_mode(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"]) + knowledge_config.process_rule = SimpleNamespace(mode="unsupported-mode", rules=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [] + assert batch == "" + + def test_save_document_with_dataset_id_upload_file_creates_and_reindexes_documents(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + duplicate_document = _make_document(document_id="doc-duplicate", name="existing.txt") + created_document = _make_document(document_id="doc-created", name="new.txt") + upload_file_a = SimpleNamespace(id="file-1", name="existing.txt") + upload_file_b = SimpleNamespace(id="file-2", name="new.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=4), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.DuplicateDocumentIndexingTaskProxy") as duplicate_proxy_cls, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [duplicate_document] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [duplicate_document, created_document] + assert batch == "20260101010101100023" + assert duplicate_document.dataset_process_rule_id == "rule-1" + assert duplicate_document.updated_at == "now" + assert duplicate_document.batch == batch + assert duplicate_document.indexing_status == "waiting" + build_document.assert_called_once_with( + dataset, + "rule-1", + "upload_file", + IndexStructureType.PARAGRAPH_INDEX, + "English", + {"upload_file_id": "file-2"}, + "web", + 4, + account_context, + "new.txt", + batch, + ) + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + duplicate_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-duplicate"]) + duplicate_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_notion_import_truncates_names_and_cleans_removed_pages( + self, account_context + ): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + notion_page_name = "a" * 300 + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-keep", page_name="Keep page", type="page"), + NotionPage( + page_id="page-new", + page_name=notion_page_name, + page_icon=NotionIcon(type="emoji", emoji="page"), + type="page", + ), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + existing_keep = _make_document(document_id="doc-keep") + existing_keep.data_source_info = json.dumps({"notion_page_id": "page-keep"}) + existing_remove = _make_document(document_id="doc-remove") + existing_remove.data_source_info = json.dumps({"notion_page_id": "page-remove"}) + created_document = _make_document(document_id="doc-new") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.clean_notion_document_task") as clean_task, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + notion_documents_query = MagicMock() + notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove] + mock_db.session.query.return_value = notion_documents_query + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert created_document in documents + assert len(build_document.call_args.args[9]) == 255 + clean_task.delay.assert_called_once_with(["doc-remove"], dataset.id) + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-new"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_website_crawl_truncates_long_urls(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + long_url = "https://example.com/" + ("a" * 260) + short_url = "https://example.com/short" + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=[long_url, short_url], + only_main_content=True, + ), + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + first_document = _make_document(document_id="doc-1") + second_document = _make_document(document_id="doc-2") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=2), + patch.object( + DocumentService, + "build_document", + side_effect=[first_document, second_document], + ) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [first_document, second_document] + assert build_document.call_args_list[0].args[9] == long_url[:200] + "..." + assert build_document.call_args_list[1].args[9] == short_url + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-1", "doc-2"]) + document_proxy_cls.return_value.delay.assert_called_once() + + +class TestDocumentServiceBatchUpdateStatus: + """Unit tests for batch_update_document_status orchestration and helper branches.""" + + def test_prepare_disable_update_requires_completed_document(self): + document = _make_document(indexing_status="waiting") + document.completed_at = None + + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService._prepare_disable_update(document, user=SimpleNamespace(id="user-1"), now="now") + + def test_prepare_archive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=False) + + result = DocumentService._prepare_archive_update(document, user=SimpleNamespace(id="user-1"), now="now") + + assert result is not None + assert result["updates"]["archived"] is True + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def test_prepare_unarchive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=True) + + result = DocumentService._prepare_unarchive_update(document, now="now") + + assert result is not None + assert result["updates"]["archived"] is False + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def test_batch_update_document_status_rejects_indexing_documents(self): + dataset = _make_dataset() + document = _make_document(name="Busy document") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = "1" + + with pytest.raises(DocumentIndexingError, match="Busy document is being indexed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "archive", SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_not_called() + + def test_batch_update_document_status_rolls_back_when_commit_fails(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = None + mock_db.session.commit.side_effect = RuntimeError("commit failed") + + with pytest.raises(RuntimeError, match="commit failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", SimpleNamespace(id="user-1") + ) + + mock_db.session.rollback.assert_called_once() + + def test_batch_update_document_status_raises_async_task_error_after_commit(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.add_document_to_index_task") as add_task, + ): + mock_redis.get.return_value = None + add_task.delay.side_effect = RuntimeError("task failed") + + with pytest.raises(RuntimeError, match="task failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + + +class TestDocumentServiceTenantAndUpdateEdges: + """Unit tests for tenant-count and update edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_tenant_documents_count_returns_query_count(self, account_context): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.count.return_value = 12 + + result = DocumentService.get_tenant_documents_count() + + assert result == 12 + mock_db.session.query.return_value.where.return_value.count.assert_called_once() + + def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.name == "upload.txt" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.commit.call_count == 3 + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_update_document_with_dataset_id_requires_upload_file_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="upload_file")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No file info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_raises_when_upload_file_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(FileNotExistsError): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_requires_notion_info_list(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_notion_import_updates_page_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", page_name="Page 2", page_icon=None, type="database"), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [binding_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "notion_import" + assert document.name == "" + assert document.data_source_info == json.dumps( + { + "credential_id": "credential-1", + "notion_workspace_id": "workspace-1", + "notion_page_id": "page-2", + "notion_page_icon": None, + "type": "database", + } + ) + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceSaveWithoutDatasetBilling: + """Unit tests for batch-count and quota branches in save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_counts_notion_pages_for_quota(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", page_name="Page 2", page_icon=None, type="page"), + ], + ), + NotionInfo( + credential_id="credential-2", + workspace_id="workspace-2", + pages=[NotionPage(page_id="page-3", page_name="Page 3", page_icon=None, type="page")], + ), + ], + ) + ), + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + features = _make_features(enabled=True) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "10"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_called_once_with(3, features) + + def test_save_document_without_dataset_id_enforces_batch_limit_for_website_urls(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com/a", "https://example.com/b"], + only_main_content=True, + ), + ) + ), + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "1"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_not_called() + + +class TestDocumentServiceEstimateValidation: + """Unit tests for estimate_args_validate branches.""" + + def test_estimate_args_validate_rejects_missing_info_list(self): + with pytest.raises(ValueError, match="Data source info is required"): + DocumentService.estimate_args_validate({}) + + def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": {"mode": "automatic", "rules": {"ignored": True}}, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"] == {} + + def test_estimate_args_validate_rejects_unknown_pre_processing_rule_id(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "unknown", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + with pytest.raises(ValueError, match="pre_processing_rules id is invalid"): + DocumentService.estimate_args_validate(args) + + def test_estimate_args_validate_deduplicates_rules_for_custom_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_stopwords", "enabled": True}, + {"id": "remove_stopwords", "enabled": False}, + ], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"]["pre_processing_rules"] == [{"id": "remove_stopwords", "enabled": False}] + + def test_estimate_args_validate_requires_summary_index_provider_name(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + "summary_index_setting": {"enable": True, "model_name": "summary-model"}, + }, + } + + with pytest.raises(ValueError, match="Summary index model provider name is required"): + DocumentService.estimate_args_validate(args) + + +class TestDocumentServiceSaveDocumentAdditionalBranches: + """Additional unit tests for dataset bootstrap and process-rule branches.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_initializes_high_quality_dataset_from_default_embedding_model( + self, account_context + ): + dataset = _make_dataset(data_source_type=None, indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = None + knowledge_config.embedding_model_provider = None + updated_document = _make_document(batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document), + ): + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace( + model_name="default-embedding", + provider="default-provider", + ) + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [updated_document] + assert batch == "batch-existing" + assert dataset.data_source_type == "upload_file" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "default-embedding" + assert dataset.embedding_model_provider == "default-provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.retrieval_model == { + "search_method": "semantic_search", + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 4, + "score_threshold_enabled": False, + } + get_binding.assert_called_once_with("default-provider", "default-embedding") + + def test_save_document_with_dataset_id_uses_explicit_embedding_and_retrieval_model(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = "explicit-model" + knowledge_config.embedding_model_provider = "explicit-provider" + knowledge_config.retrieval_model = RetrievalModel( + search_method="semantic_search", + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=7, + score_threshold_enabled=True, + score_threshold=0.3, + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=_make_document()), + ): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called() + get_binding.assert_called_once_with("explicit-provider", "explicit-model") + assert dataset.embedding_model == "explicit-model" + assert dataset.embedding_model_provider == "explicit-provider" + assert dataset.retrieval_model == knowledge_config.retrieval_model.model_dump() + + def test_save_document_with_dataset_id_creates_custom_process_rule_for_new_upload_document(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + created_process_rule = SimpleNamespace(id="rule-custom") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=3), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [created_document] + assert batch == "20260101010101100023" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "custom", + "rules": knowledge_config.process_rule.rules.model_dump_json(), + "created_by": "user-1", + } + document_proxy_cls.assert_called_once_with("tenant-1", "dataset-1", ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_creates_automatic_process_rule_for_new_upload_document( + self, account_context + ): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="automatic"), + ) + created_process_rule = SimpleNamespace(id="rule-auto") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.flush.call_count >= 2 + + def test_save_document_with_dataset_id_creates_fallback_automatic_process_rule_when_latest_is_missing( + self, account_context + ): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"], process_rule=None) + created_process_rule = SimpleNamespace(id="rule-fallback") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + + def test_save_document_with_dataset_id_raises_when_upload_file_lookup_is_incomplete(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + mock_db.session.query.return_value = upload_query + + with pytest.raises(FileNotExistsError, match="One or more files not found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_notion_info_list_for_notion_import(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) + + def test_save_document_with_dataset_id_requires_website_info_list_for_website_crawl(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="website_crawl")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index bd226f7536..9a513c3fe6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -70,16 +71,16 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch # Minimal knowledge_config stub that satisfies pre-lock code info_list = types.SimpleNamespace(data_source_type="upload_file") data_source = types.SimpleNamespace(info_list=info_list) knowledge_config = types.SimpleNamespace( - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model=None, embedding_model_provider=None, retrieval_model=None, @@ -125,13 +126,13 @@ def test_add_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # skip embedding/token calculation branch + dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX # Minimal args required by add_segment args = { @@ -168,10 +169,10 @@ def test_multi_create_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # again, skip high_quality path + dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path document = create_autospec(Document, instance=True) document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py new file mode 100644 index 0000000000..2f8ae14a8e --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -0,0 +1,1017 @@ +"""Unit tests for SegmentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + ChildChunk, + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + ChildChunkUpdateArgs, + DocumentSegment, + IndexStructureType, + MagicMock, + ModelType, + SegmentService, + SegmentUpdateArgs, + SimpleNamespace, + _make_child_chunk, + _make_dataset, + _make_document, + _make_lock_context, + _make_segment, + create_autospec, + patch, + pytest, +) + + +class TestSegmentServiceChildChunks: + """Unit tests for child-chunk CRUD helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_child_chunk_assigns_next_position_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = 2 + + child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) + + assert isinstance(child_chunk, ChildChunk) + assert child_chunk.position == 3 + assert child_chunk.index_node_id == "node-1" + assert child_chunk.index_node_hash == "hash-1" + assert child_chunk.word_count == len("child content") + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.create_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.create_child_chunk("child content", segment, document, dataset) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_update_child_chunks_updates_deletes_and_creates_records(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_a = ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + existing_b = ChildChunk( + id="child-b", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=2, + content="remove me", + word_count=9, + created_by="user-1", + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-new"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-new"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_a, existing_b] + + result = SegmentService.update_child_chunks( + [ + ChildChunkUpdateArgs(id="child-a", content="updated content"), + ChildChunkUpdateArgs(content="brand new"), + ], + segment, + document, + dataset, + ) + + assert [chunk.position for chunk in result] == [1, 3] + assert existing_a.content == "updated content" + assert existing_a.updated_by == account_context.id + assert existing_a.updated_at == "now" + mock_db.session.bulk_save_objects.assert_called_once_with([existing_a]) + mock_db.session.delete.assert_called_once_with(existing_b) + new_chunk = result[1] + assert isinstance(new_chunk, ChildChunk) + assert new_chunk.position == 3 + assert new_chunk.index_node_id == "node-new" + vector_service.update_child_chunk_vector.assert_called_once_with( + [new_chunk], [existing_a], [existing_b], dataset + ) + mock_db.session.commit.assert_called_once() + + def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_chunk] + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunks( + [ChildChunkUpdateArgs(id="child-a", content="updated content")], + segment, + document, + dataset, + ) + + mock_db.session.rollback.assert_called_once() + + def test_update_child_chunk_updates_vector_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + result = SegmentService.update_child_chunk( + "new content", child_chunk, _make_segment(), _make_document(), dataset + ) + + assert result is child_chunk + assert child_chunk.content == "new content" + assert child_chunk.word_count == len("new content") + assert child_chunk.updated_by == "user-1" + assert child_chunk.updated_at == "now" + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.update_child_chunk_vector.assert_called_once_with([], [child_chunk], [], dataset) + mock_db.session.commit.assert_called_once() + + def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.delete_child_chunk_vector.side_effect = RuntimeError("delete failed") + + with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + mock_db.session.rollback.assert_called_once() + + +class TestSegmentServiceQueries: + """Unit tests for child-chunk and segment query helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_child_chunks_applies_keyword_filter_and_paginate(self, account_context): + paginated = SimpleNamespace(items=["chunk"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + result = SegmentService.get_child_chunks( + segment_id="segment-1", + document_id="doc-1", + dataset_id="dataset-1", + page=2, + limit=10, + keyword="needle", + ) + + assert result is paginated + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self): + child_chunk = _make_child_chunk() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is child_chunk + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is None + + def test_get_segments_uses_status_and_keyword_filters(self): + paginated = SimpleNamespace(items=["segment"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + items, total = SegmentService.get_segments( + document_id="doc-1", + tenant_id="tenant-1", + status_list=["completed"], + keyword="needle", + page=1, + limit=20, + ) + + assert items == ["segment"] + assert total == 1 + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_segment_by_id_returns_only_document_segment_instances(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = segment + result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is segment + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is None + + def test_get_segments_by_document_and_dataset_returns_scalars_result(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [segment] + + result = SegmentService.get_segments_by_document_and_dataset( + document_id="doc-1", + dataset_id="dataset-1", + status="completed", + enabled=True, + ) + + assert result == [segment] + mock_db.session.scalars.assert_called_once() + + +class TestSegmentServiceValidation: + """Unit tests for segment-create argument validation.""" + + def test_segment_create_args_validate_requires_answer_for_qa_model(self): + document = _make_document(doc_form=IndexStructureType.QA_INDEX) + + with pytest.raises(ValueError, match="Answer is required"): + SegmentService.segment_create_args_validate({"content": "question"}, document) + + def test_segment_create_args_validate_requires_non_empty_content(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Content is empty"): + SegmentService.segment_create_args_validate({"content": " "}, document) + + def test_segment_create_args_validate_enforces_attachment_limit(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + args = {"content": "hello", "attachment_ids": ["a-1", "a-2"]} + + with patch("services.dataset_service.dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT", 1): + with pytest.raises(ValueError, match="Exceeded maximum attachment limit of 1"): + SegmentService.segment_create_args_validate(args, document) + + def test_segment_create_args_validate_requires_attachment_ids_list(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Attachment IDs is invalid"): + SegmentService.segment_create_args_validate({"content": "hello", "attachment_ids": "bad-type"}, document) + + +class TestSegmentServiceMutations: + """Unit tests for segment create, update, delete, and bulk status flows.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_segment_creates_bindings_and_marks_segment_error_on_vector_failure(self, account_context): + dataset = _make_dataset(indexing_technique="economy") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=0, + ) + refreshed_segment = SimpleNamespace(id="segment-1") + args = { + "content": "question", + "answer": "answer", + "keywords": ["kw-1"], + "attachment_ids": ["att-1", "att-2"], + } + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + + max_position_query = MagicMock() + max_position_query.where.return_value.scalar.return_value = 2 + refresh_query = MagicMock() + refresh_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [max_position_query, refresh_query] + + def add_side_effect(obj): + if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None: + obj.id = "segment-1" + + mock_db.session.add.side_effect = add_side_effect + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + + created_segment = vector_service.create_segments_vector.call_args.args[1][0] + attachment_bindings = [ + call.args[0] + for call in mock_db.session.add.call_args_list + if call.args and call.args[0].__class__.__name__ == "SegmentAttachmentBinding" + ] + + assert result is refreshed_segment + assert created_segment.position == 3 + assert created_segment.answer == "answer" + assert created_segment.word_count == len("question") + len("answer") + assert created_segment.status == "error" + assert created_segment.enabled is False + assert created_segment.error == "vector failed" + assert document.word_count == len("question") + len("answer") + assert len(attachment_bindings) == 2 + assert {binding.attachment_id for binding in attachment_bindings} == {"att-1", "att-2"} + assert mock_db.session.commit.call_count == 3 + + def test_multi_create_segment_high_quality_marks_segments_error_when_vector_creation_fails(self, account_context): + dataset = _make_dataset(indexing_technique="high_quality") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=5, + ) + segments = [ + {"content": "question-1", "answer": "answer-1", "keywords": ["k1"]}, + {"content": "question-2", "answer": "answer-2"}, + ] + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.side_effect = [[11], [13]] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", side_effect=["hash-1", "hash-2"]), + patch("services.dataset_service.uuid.uuid4", side_effect=["node-1", "node-2"]), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.multi_create_segment(segments, document, dataset) + + assert len(result) == 2 + assert [segment.position for segment in result] == [2, 3] + assert [segment.tokens for segment in result] == [11, 13] + assert all(segment.status == "error" for segment in result) + assert all(segment.enabled is False for segment in result) + assert all(segment.error == "vector failed" for segment in result) + assert document.word_count == 5 + sum(len(item["content"]) + len(item["answer"]) for item in segments) + vector_service.create_segments_vector.assert_called_once_with( + [["k1"], None], result, dataset, document.doc_form + ) + mock_db.session.commit.assert_called_once() + + def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + args = SegmentUpdateArgs(enabled=False) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segment_from_index_task") as disable_task, + ): + mock_redis.get.return_value = None + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.disabled_by == account_context.id + mock_db.session.add.assert_called_once_with(segment) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_indexing", 600, 1) + disable_task.delay.assert_called_once_with(segment.id) + + def test_update_segment_rejects_updates_for_disabled_segment(self, account_context): + segment = _make_segment(enabled=False) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = None + + with pytest.raises(ValueError, match="Can't update disabled segment"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_rejects_when_indexing_cache_exists(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is indexing"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_updates_keywords_for_same_content_segment(self, account_context): + segment = _make_segment(content="same content", keywords=["old"]) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + args = SegmentUpdateArgs(content="same content", keywords=["new"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.keywords == ["new"] + vector_service.update_segment_vector.assert_called_once_with(["new"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self, account_context): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + args = SegmentUpdateArgs( + content="same content", + regenerate_child_chunks=True, + summary="new summary", + ) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_auto_regenerates_summary_after_content_change(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [9] + args = SegmentUpdateArgs(content="new content", keywords=["kw-1"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.content == "new content" + assert segment.index_node_hash == "hash-1" + assert segment.tokens == 9 + assert document.word_count == 18 + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="same summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [7] + args = SegmentUpdateArgs(content="new text", summary="same summary") + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-2"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + update_summary.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_delete_segment_removes_index_and_updates_document_word_count(self): + segment = _make_segment(word_count=4, index_node_id="parent-node") + document = _make_document(word_count=10) + dataset = _make_dataset() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)] + + SegmentService.delete_segment(segment, document, dataset) + + assert document.word_count == 6 + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1) + delete_task.delay.assert_called_once_with( + ["parent-node"], + dataset.id, + document.id, + [segment.id], + ["child-1", "child-2"], + ) + mock_db.session.delete.assert_called_once_with(segment) + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_segment_rejects_when_delete_is_already_in_progress(self): + segment = _make_segment() + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is deleting"): + SegmentService.delete_segment(segment, document, dataset) + + def test_delete_segments_removes_records_and_clamps_document_word_count(self): + dataset = _make_dataset() + document = _make_document(word_count=3) + current_user = SimpleNamespace(current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + segments_query = MagicMock() + segments_query.with_entities.return_value.where.return_value.all.return_value = [ + ("node-1", "segment-1", 2), + ("node-2", "segment-2", 5), + ] + child_query = MagicMock() + child_query.where.return_value.all.return_value = [("child-1",)] + delete_query = MagicMock() + delete_query.where.return_value.delete.return_value = 2 + mock_db.session.query.side_effect = [segments_query, child_query, delete_query] + + SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) + + assert document.word_count == 0 + mock_db.session.add.assert_called_once_with(document) + delete_task.delay.assert_called_once_with( + ["node-1", "node-2"], + dataset.id, + document.id, + ["segment-1", "segment-2"], + ["child-1"], + ) + delete_query.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_segments_status_enables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=False) + segment_b = _make_segment(segment_id="segment-b", enabled=False) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.enable_segments_to_index_task") as enable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document) + + assert segment_a.enabled is True + assert segment_a.disabled_at is None + assert segment_a.disabled_by is None + assert segment_b.enabled is False + mock_db.session.add.assert_called_once_with(segment_a) + mock_db.session.commit.assert_called_once() + enable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + def test_update_segments_status_disables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=True) + segment_b = _make_segment(segment_id="segment-b", enabled=True) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segments_from_index_task") as disable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "disable", dataset, document) + + assert segment_a.enabled is False + assert segment_a.disabled_at == "now" + assert segment_a.disabled_by == current_user.id + assert segment_b.enabled is True + mock_db.session.add.assert_called_once_with(segment_a) + mock_db.session.commit.assert_called_once() + disable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + +class TestSegmentServiceChildChunkTailHelpers: + """Unit tests for the remaining child-chunk helper branches.""" + + def test_update_child_chunk_rolls_back_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunk( + "new content", child_chunk, SimpleNamespace(), SimpleNamespace(), dataset + ) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_delete_child_chunk_commits_after_successful_vector_delete(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + vector_service.delete_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + +class TestSegmentServiceAdditionalRegenerationBranches: + """Additional unit tests for segment update and regeneration edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_update_segment_same_content_updates_answer_and_document_word_count_for_qa_segments(self, account_context): + segment = _make_segment(content="question", word_count=8) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="question", answer="new answer"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.answer == "new answer" + assert segment.word_count == len("question") + len("new answer") + assert document.word_count == 20 + (len("question") + len("new answer") - 8) + vector_service.update_segment_vector.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_segments(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [21] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-qa"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = None + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new questionnew answer"]) + assert segment.answer == "new answer" + assert segment.tokens == 21 + assert segment.word_count == len("new question") + len("new answer") + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_parent_child_uses_default_embedding_and_ignores_summary_failures( + self, account_context + ): + segment = _make_segment(content="old", word_count=3) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=10, + ) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.embedding_model_provider = None + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-parent"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance + update_summary.side_effect = RuntimeError("summary failed") + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_same_content_parent_child_marks_segment_error_for_non_high_quality_dataset( + self, account_context + ): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) + dataset = _make_dataset(indexing_technique="economy") + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="same content", regenerate_child_chunks=True), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.status == "error" + assert segment.error == "The knowledge base index technique is not high quality!" + vector_service.update_multimodel_vector.assert_not_called() diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 105ef7ba48..da93239600 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from dify_graph.model_runtime.entities.provider_entities import FormType +from graphon.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py deleted file mode 100644 index a7e1a011f6..0000000000 --- a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Unit tests for archived workflow run deletion service. -""" - -from unittest.mock import MagicMock, patch - - -class TestArchivedWorkflowRunDeletion: - def test_delete_by_run_id_calls_delete_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion() - repo = MagicMock() - repo.get_archived_run_ids.return_value = {"run-1"} - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - session = MagicMock() - session.get.return_value = run - - session_maker = MagicMock() - session_maker.return_value.__enter__.return_value = session - session_maker.return_value.__exit__.return_value = None - mock_db = MagicMock() - mock_db.engine = MagicMock() - - with ( - patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), - patch( - "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", - return_value=session_maker, - autospec=True, - ), - patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True), - patch.object( - deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True - ) as mock_delete_run, - ): - result = deleter.delete_by_run_id("run-1") - - assert result.success is True - mock_delete_run.assert_called_once_with(run) - - def test_delete_run_dry_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion(dry_run=True) - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo: - result = deleter._delete_run(run) - - assert result.success is True - mock_get_repo.assert_not_called() diff --git a/api/tests/unit_tests/services/test_document_service_display_status.py b/api/tests/unit_tests/services/test_document_service_display_status.py deleted file mode 100644 index cb2e2940c8..0000000000 --- a/api/tests/unit_tests/services/test_document_service_display_status.py +++ /dev/null @@ -1,8 +0,0 @@ -from services.dataset_service import DocumentService - - -def test_normalize_display_status_alias_mapping(): - assert DocumentService.normalize_display_status("ACTIVE") == "available" - assert DocumentService.normalize_display_status("enabled") == "available" - assert DocumentService.normalize_display_status("archived") == "archived" - assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py deleted file mode 100644 index a3b1f46436..0000000000 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ /dev/null @@ -1,841 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, DefaultEndUserSessionID, EndUser -from services.end_user_service import EndUserService - - -class TestEndUserServiceFactory: - """Factory class for creating test data and mock objects for end user service tests.""" - - @staticmethod - def create_app_mock( - app_id: str = "app-123", - tenant_id: str = "tenant-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock(spec=App) - app.id = app_id - app.tenant_id = tenant_id - app.name = name - return app - - @staticmethod - def create_end_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-456", - app_id: str = "app-123", - session_id: str = "session-001", - type: InvokeFrom = InvokeFrom.SERVICE_API, - is_anonymous: bool = False, - ) -> MagicMock: - """Create a mock EndUser object.""" - end_user = MagicMock(spec=EndUser) - end_user.id = user_id - end_user.tenant_id = tenant_id - end_user.app_id = app_id - end_user.session_id = session_id - end_user.type = type - end_user.is_anonymous = is_anonymous - end_user.external_user_id = session_id - return end_user - - -class TestEndUserServiceGetEndUserById: - """Unit tests for EndUserService.get_end_user_by_id method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): - """Test successful retrieval of end user by ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = mock_end_user - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result == mock_end_user - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.first.assert_called_once() - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): - """Test retrieval of non-existent end user returns None.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result is None - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): - """Test that query parameters are correctly applied.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - # Verify the where clause was called with the correct conditions - call_args = mock_query.where.call_args[0] - assert len(call_args) == 3 - # Check that the conditions match the expected filters - # (We can't easily test the exact conditions without importing SQLAlchemy) - - -class TestEndUserServiceGetOrCreateEndUser: - """Unit tests for EndUserService.get_or_create_end_user method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user with specific user_id.""" - # Arrange - app_mock = factory.create_app_mock() - user_id = "user-123" - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, user_id) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id - ) - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user without user_id (None).""" - # Arrange - app_mock = factory.create_app_mock() - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, None) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None - ) - - -class TestEndUserServiceGetOrCreateEndUserByType: - """ - Unit tests for EndUserService.get_or_create_end_user_by_type method. - - This test suite covers: - - Creating end users with different InvokeFrom types - - Type migration for legacy users - - Query ordering and prioritization - - Session management - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): - """Test creating a new end user with specific user_id.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - # Verify new EndUser was created with correct parameters - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - added_user = mock_session.add.call_args[0][0] - assert added_user.tenant_id == tenant_id - assert added_user.app_id == app_id - assert added_user.type == type_enum - assert added_user.session_id == user_id - assert added_user.external_user_id == user_id - assert added_user._is_anonymous is False - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): - """Test creating a new end user with default session ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = None - type_enum = InvokeFrom.WEB_APP - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): - """Test retrieving existing user with same type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - mock_session.add.assert_not_called() - mock_session.commit.assert_not_called() - mock_logger.info.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): - """Test upgrading existing user with different type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - old_type = InvokeFrom.WEB_APP - new_type = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - assert existing_user.type == new_type - mock_session.commit.assert_called_once() - mock_logger.info.assert_called_once() - logger_call_args = mock_logger.info.call_args[0] - assert "Upgrading legacy EndUser" in logger_call_args[0] - # The old and new types are passed as separate arguments - assert mock_logger.info.call_args[0][1] == existing_user.id - assert mock_logger.info.call_args[0][2] == old_type - assert mock_logger.info.call_args[0][3] == new_type - assert mock_logger.info.call_args[0][4] == user_id - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): - """Test that query ordering prioritizes exact type matches.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - target_type = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - mock_query.order_by.assert_called_once() - # Verify that case statement is used for ordering - order_by_call = mock_query.order_by.call_args[0][0] - # The exact structure depends on SQLAlchemy's case implementation - # but we can verify it was called - - # Test 10: Session context manager properly closes - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_session_context_manager_closes(self, mock_db, mock_session_class, factory): - """Test that Session context manager is properly used.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_id=app_id, - user_id=user_id, - ) - - # Assert - # Verify context manager was entered and exited - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): - """Test that all InvokeFrom enum values are supported.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.type == invoke_type - - -class TestEndUserServiceCreateEndUserBatch: - """Unit tests for EndUserService.create_end_user_batch method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): - """Test batch creation with empty app_ids list.""" - # Arrange - tenant_id = "tenant-123" - app_ids: list[str] = [] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert result == {} - mock_session_class.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_default_session_id(self, mock_db, mock_session_class): - """Test batch creation with empty user_id (uses default session).""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - for app_id, end_user in result.items(): - assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): - """Test that duplicate app_ids are deduplicated while preserving order.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - # Should have 3 unique app_ids in original order - assert len(result) == 3 - assert "app-456" in result - assert "app-789" in result - assert "app-123" in result - - # Verify the order is preserved - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 3 - assert added_users[0].app_id == "app-456" - assert added_users[1].app_id == "app-789" - assert added_users[2].app_id == "app-123" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation when all users already exist.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - assert result["app-456"] == existing_user1 - assert result["app-789"] == existing_user2 - mock_session.add_all.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation with some existing and some new users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - # app-789 and app-123 don't exist - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 3 - assert result["app-456"] == existing_user1 - assert "app-789" in result - assert "app-123" in result - - # Should create 2 new users - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 2 - - mock_session.commit.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): - """Test batch creation handles duplicates in existing users gracefully.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Simulate duplicate records in database - existing_user1 = factory.create_end_user_mock( - user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - # Should prefer the first one found - assert result["app-456"] == existing_user1 - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): - """Test batch creation with all InvokeFrom types.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - added_user = mock_session.add_all.call_args[0][0][0] - assert added_user.type == invoke_type - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): - """Test batch creation with single app_id.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - assert "app-456" in result - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 1 - assert added_users[0].app_id == "app-456" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): - """Test batch creation correctly sets anonymous flag.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - - # Test with regular user ID - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - authenticated user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is False - - # Test with default session ID - mock_session.reset_mock() - mock_query.reset_mock() - mock_query.all.return_value = [] - - # Act - anonymous user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_ids=app_ids, - user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): - """Test that batch creation uses efficient single query for existing users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - # Should make exactly one query to check for existing users - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.all.assert_called_once() - - # Verify the where clause uses .in_() for app_ids - where_call = mock_query.where.call_args[0] - # The exact structure depends on SQLAlchemy implementation - # but we can verify it was called with the right parameters - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_session_context_manager(self, mock_db, mock_session_class): - """Test that batch creation properly uses session context manager.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_feedback_service.py b/api/tests/unit_tests/services/test_feedback_service.py deleted file mode 100644 index 1f70839ee2..0000000000 --- a/api/tests/unit_tests/services/test_feedback_service.py +++ /dev/null @@ -1,626 +0,0 @@ -import csv -import io -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.feedback_service import FeedbackService - - -class TestFeedbackServiceFactory: - """Factory class for creating test data and mock objects for feedback service tests.""" - - @staticmethod - def create_feedback_mock( - feedback_id: str = "feedback-123", - app_id: str = "app-456", - conversation_id: str = "conv-789", - message_id: str = "msg-001", - rating: str = "like", - content: str | None = "Great response!", - from_source: str = "user", - from_account_id: str | None = None, - from_end_user_id: str | None = "end-user-001", - created_at: datetime | None = None, - ) -> MagicMock: - """Create a mock MessageFeedback object.""" - feedback = MagicMock() - feedback.id = feedback_id - feedback.app_id = app_id - feedback.conversation_id = conversation_id - feedback.message_id = message_id - feedback.rating = rating - feedback.content = content - feedback.from_source = from_source - feedback.from_account_id = from_account_id - feedback.from_end_user_id = from_end_user_id - feedback.created_at = created_at or datetime.now() - return feedback - - @staticmethod - def create_message_mock( - message_id: str = "msg-001", - query: str = "What is AI?", - answer: str = "AI stands for Artificial Intelligence.", - inputs: dict | None = None, - created_at: datetime | None = None, - ): - """Create a mock Message object.""" - - # Create a simple object with instance attributes - # Using a class with __init__ ensures attributes are instance attributes - class Message: - def __init__(self): - self.id = message_id - self.query = query - self.answer = answer - self.inputs = inputs - self.created_at = created_at or datetime.now() - - return Message() - - @staticmethod - def create_conversation_mock( - conversation_id: str = "conv-789", - name: str | None = "Test Conversation", - ) -> MagicMock: - """Create a mock Conversation object.""" - conversation = MagicMock() - conversation.id = conversation_id - conversation.name = name - return conversation - - @staticmethod - def create_app_mock( - app_id: str = "app-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock() - app.id = app_id - app.name = name - return app - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - name: str = "Test Admin", - ) -> MagicMock: - """Create a mock Account object.""" - account = MagicMock() - account.id = account_id - account.name = name - return account - - -class TestFeedbackService: - """ - Comprehensive unit tests for FeedbackService. - - This test suite covers: - - CSV and JSON export formats - - All filter combinations - - Edge cases and error handling - - Response validation - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestFeedbackServiceFactory() - - @pytest.fixture - def sample_feedback_data(self, factory): - """Create sample feedback data for testing.""" - feedback = factory.create_feedback_mock( - rating="like", - content="Excellent answer!", - from_source="user", - ) - message = factory.create_message_mock( - query="What is Python?", - answer="Python is a programming language.", - ) - conversation = factory.create_conversation_mock(name="Python Discussion") - app = factory.create_app_mock(name="AI Assistant") - account = factory.create_account_mock(name="Admin User") - - return [(feedback, message, conversation, app, account)] - - # Test 01: CSV Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data): - """Test basic CSV export with single feedback record.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - assert response.mimetype == "text/csv" - assert "charset=utf-8-sig" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"] - - # Verify CSV content - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - - assert len(rows) == 1 - assert rows[0]["feedback_rating"] == "👍" - assert rows[0]["feedback_rating_raw"] == "like" - assert rows[0]["feedback_comment"] == "Excellent answer!" - assert rows[0]["user_query"] == "What is Python?" - assert rows[0]["ai_response"] == "Python is a programming language." - - # Test 02: JSON Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data): - """Test basic JSON export with metadata structure.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - assert response.mimetype == "application/json" - assert "charset=utf-8" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - - # Verify JSON structure - json_content = json.loads(response.get_data(as_text=True)) - assert "export_info" in json_content - assert "feedback_data" in json_content - assert json_content["export_info"]["app_id"] == "app-456" - assert json_content["export_info"]["total_records"] == 1 - assert len(json_content["feedback_data"]) == 1 - - # Test 03: Filter by from_source - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_from_source(self, mock_db, factory): - """Test filtering by feedback source (user/admin).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", from_source="admin") - - # Assert - mock_query.filter.assert_called() - - # Test 04: Filter by rating - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_rating(self, mock_db, factory): - """Test filtering by rating (like/dislike).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", rating="dislike") - - # Assert - mock_query.filter.assert_called() - - # Test 05: Filter by has_comment (True) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory): - """Test filtering for feedback with comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=True) - - # Assert - mock_query.filter.assert_called() - - # Test 06: Filter by has_comment (False) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory): - """Test filtering for feedback without comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=False) - - # Assert - mock_query.filter.assert_called() - - # Test 07: Filter by date range - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_date_range(self, mock_db, factory): - """Test filtering by start and end dates.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - assert mock_query.filter.call_count >= 2 # Called for both start and end dates - - # Test 08: Invalid date format - start_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_start_date(self, mock_db): - """Test error handling for invalid start_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid start_date format"): - FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date") - - # Test 09: Invalid date format - end_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_end_date(self, mock_db): - """Test error handling for invalid end_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid end_date format"): - FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45") - - # Test 10: Unsupported format - def test_export_feedbacks_unsupported_format(self): - """Test error handling for unsupported export format.""" - # Act & Assert - with pytest.raises(ValueError, match="Unsupported format"): - FeedbackService.export_feedbacks(app_id="app-456", format_type="xml") - - # Test 11: Empty result set - CSV - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_csv(self, mock_db): - """Test CSV export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - assert len(rows) == 0 - # But headers should still be present - assert reader.fieldnames is not None - - # Test 12: Empty result set - JSON - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_json(self, mock_db): - """Test JSON export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["export_info"]["total_records"] == 0 - assert len(json_content["feedback_data"]) == 0 - - # Test 13: Long response truncation - @patch("services.feedback_service.db") - def test_export_feedbacks_long_response_truncation(self, mock_db, factory): - """Test that long AI responses are truncated to 500 characters.""" - # Arrange - long_answer = "A" * 600 # 600 characters - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(answer=long_answer) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - ai_response = json_content["feedback_data"][0]["ai_response"] - assert len(ai_response) == 503 # 500 + "..." - assert ai_response.endswith("...") - - # Test 14: Null account (end user feedback) - @patch("services.feedback_service.db") - def test_export_feedbacks_null_account(self, mock_db, factory): - """Test handling of feedback from end users (no account).""" - # Arrange - feedback = factory.create_feedback_mock(from_account_id=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = None # No account for end user - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["from_account_name"] == "" - - # Test 15: Null conversation name - @patch("services.feedback_service.db") - def test_export_feedbacks_null_conversation_name(self, mock_db, factory): - """Test handling of conversations without names.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock() - conversation = factory.create_conversation_mock(name=None) - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["conversation_name"] == "" - - # Test 16: Dislike rating emoji - @patch("services.feedback_service.db") - def test_export_feedbacks_dislike_rating(self, mock_db, factory): - """Test that dislike rating shows thumbs down emoji.""" - # Arrange - feedback = factory.create_feedback_mock(rating="dislike") - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_rating"] == "👎" - assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike" - - # Test 17: Combined filters - @patch("services.feedback_service.db") - def test_export_feedbacks_combined_filters(self, mock_db, factory): - """Test applying multiple filters simultaneously.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - from_source="admin", - rating="like", - has_comment=True, - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - # Should have called filter multiple times for each condition - assert mock_query.filter.call_count >= 4 - - # Test 18: Message query fallback to inputs - @patch("services.feedback_service.db") - def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory): - """Test fallback to inputs.query when message.query is None.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"}) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["user_query"] == "Query from inputs" - - # Test 19: Empty feedback content - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_feedback_content(self, mock_db, factory): - """Test handling of feedback with empty/null content.""" - # Arrange - feedback = factory.create_feedback_mock(content=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_comment"] == "" - assert json_content["feedback_data"][0]["has_comment"] == "No" - - # Test 20: CSV headers validation - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data): - """Test that CSV contains all expected headers.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - expected_headers = [ - "feedback_id", - "app_name", - "app_id", - "conversation_id", - "conversation_name", - "message_id", - "user_query", - "ai_response", - "feedback_rating", - "feedback_rating_raw", - "feedback_comment", - "feedback_source", - "feedback_date", - "message_date", - "from_account_name", - "from_end_user_id", - "has_comment", - ] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - assert list(reader.fieldnames) == expected_headers diff --git a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py deleted file mode 100644 index 7b4d349e33..0000000000 --- a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Unit tests for `services.file_service.FileService` helpers. - -We keep these tests focused on: -- ZIP tempfile building (sanitization + deduplication + content writes) -- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`) -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from zipfile import ZipFile - -import pytest - -import services.file_service as file_service_module -from services.file_service import FileService - - -def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure ZIP entry names are safe and unique while preserving extensions.""" - - # Arrange: three upload files that all sanitize down to the same basename ("b.txt"). - upload_files: list[Any] = [ - SimpleNamespace(name="a/b.txt", key="k1"), - SimpleNamespace(name="c/b.txt", key="k2"), - SimpleNamespace(name="../b.txt", key="k3"), - ] - - # Stream distinct bytes per key so we can verify content is written to the right entry. - data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} - - def _load(key: str, stream: bool = True) -> list[bytes]: - # Return the corresponding chunks for this key (the production code iterates chunks). - assert stream is True - return data_by_key[key] - - monkeypatch.setattr(file_service_module.storage, "load", _load) - - # Act: build zip in a tempfile. - with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: - with ZipFile(tmp, mode="r") as zf: - # Assert: names are sanitized (no directory components) and deduped with suffixes. - assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] - - # Assert: each entry contains the correct bytes from storage. - assert zf.read("b.txt") == b"one" - assert zf.read("b (1).txt") == b"two" - assert zf.read("b (2).txt") == b"three" - - -def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure empty input returns an empty mapping without hitting the database.""" - - class _Session: - def scalars(self, _stmt): # type: ignore[no-untyped-def] - raise AssertionError("db.session.scalars should not be called for empty id lists") - - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session())) - - assert FileService.get_upload_files_by_ids("tenant-1", []) == {} - - -def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" - - upload_files: list[Any] = [ - SimpleNamespace(id="file-1", tenant_id="tenant-1"), - SimpleNamespace(id="file-2", tenant_id="tenant-1"), - ] - - class _ScalarResult: - def __init__(self, items: list[Any]) -> None: - self._items = items - - def all(self) -> list[Any]: - return self._items - - class _Session: - def __init__(self, items: list[Any]) -> None: - self._items = items - self.calls: list[object] = [] - - def scalars(self, stmt): # type: ignore[no-untyped-def] - # Capture the statement so we can at least assert the query path is taken. - self.calls.append(stmt) - return _ScalarResult(self._items) - - session = _Session(upload_files) - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session)) - - # Provide duplicates to ensure callers can safely pass repeated ids. - result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"]) - - assert set(result.keys()) == {"file-1", "file-2"} - assert result["file-1"].id == "file-1" - assert result["file-2"].id == "file-2" - assert len(session.calls) == 1 diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 375e47d7fc..55af564821 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -9,12 +9,13 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( +from graphon.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) -from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( Form, @@ -51,11 +52,11 @@ def sample_form_record(): inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow() + timedelta(hours=1), + expiration_time=naive_utc_now() + timedelta(hours=1), ), rendered_content="

hello

", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), + created_at=naive_utc_now(), + expiration_time=naive_utc_now() + timedelta(hours=1), status=HumanInputFormStatus.WAITING, selected_action_id=None, submitted_data=None, @@ -101,8 +102,8 @@ def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_rec service = HumanInputService(session_factory) expired_record = dataclasses.replace( sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), + created_at=naive_utc_now() - timedelta(hours=2), + expiration_time=naive_utc_now() + timedelta(hours=2), ) monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) @@ -391,7 +392,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): service = HumanInputService(session_factory) # Submitted - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service.ensure_form_active(Form(submitted_record)) @@ -402,7 +403,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): # Expired time expired_time_record = dataclasses.replace( - sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + sample_form_record, expiration_time=naive_utc_now() - timedelta(minutes=1) ) with pytest.raises(FormExpiredError): service.ensure_form_active(Form(expired_time_record)) @@ -411,7 +412,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory service = HumanInputService(session_factory) - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service._ensure_not_submitted(Form(submitted_record)) diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py index bc0caee071..53c243ad71 100644 --- a/api/tests/unit_tests/services/test_knowledge_service.py +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService class TestKnowledgeService: @@ -24,7 +24,7 @@ class TestKnowledgeService: mock_client = MagicMock() mock_boto_client.return_value = mock_client - retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5) query = "test query" knowledge_id = "kb-123" @@ -87,7 +87,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -104,7 +107,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -114,7 +120,7 @@ class TestKnowledgeService: with patch("services.knowledge_service.boto3.client") as mock_boto: mock_boto.side_effect = Exception("client init failed") with pytest.raises(Exception) as exc_info: - ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb") assert "client init failed" in str(exc_info.value) # ===== Edge Cases ===== @@ -139,7 +145,10 @@ class TestKnowledgeService: # Act # retrieval_setting missing "score_threshold" - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert len(result["records"]) == 1 diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index e7740ef93a..101b9bff24 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -933,7 +933,7 @@ class TestMessageServiceSuggestedQuestions: ) # Test 28: get_suggested_questions_after_answer - Advanced Chat success - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.WorkflowService") @patch("services.message_service.AdvancedChatAppConfigManager") @patch("services.message_service.TokenBufferMemory") @@ -983,7 +983,7 @@ class TestMessageServiceSuggestedQuestions: # Test 29: get_suggested_questions_after_answer - Chat app success (no override) @patch("services.message_service.db") - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.TokenBufferMemory") @patch("services.message_service.LLMGenerator") @patch("services.message_service.TraceQueueManager") diff --git a/api/tests/unit_tests/services/test_metadata_partial_update.py b/api/tests/unit_tests/services/test_metadata_partial_update.py deleted file mode 100644 index 60252784bc..0000000000 --- a/api/tests/unit_tests/services/test_metadata_partial_update.py +++ /dev/null @@ -1,187 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.entities.knowledge_entities.knowledge_entities import ( - DocumentMetadataOperation, - MetadataDetail, - MetadataOperationData, -) -from services.metadata_service import MetadataService - - -class TestMetadataPartialUpdate(unittest.TestCase): - def setUp(self): - self.dataset = MagicMock(spec=Dataset) - self.dataset.id = "dataset_id" - self.dataset.built_in_field_enabled = False - - self.document = MagicMock(spec=Document) - self.document.id = "doc_id" - self.document.doc_metadata = {"existing_key": "existing_value"} - self.document.data_source_type = "upload_file" - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query for existing bindings - - # No existing binding for new key - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains BOTH existing and new keys - expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings were NOT deleted - # The delete call in the original code: db.session.query(...).filter_by(...).delete() - # In partial update, this should NOT be called. - mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Input data (partial_update=False by default) - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains ONLY the new key - expected_metadata = {"new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings WERE deleted - # In full update (default), we expect the existing bindings to be cleared. - mock_db.session.query.return_value.filter_by.return_value.delete.assert_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_skips_existing_binding( - self, mock_redis, mock_current_account, mock_document_service, mock_db - ): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query to return an existing binding - # This simulates that the document ALREADY has the metadata we are trying to add - mock_existing_binding = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # We verify that db.session.add was NOT called for DatasetMetadataBinding - # Since we can't easily check "not called with specific type" on the generic add method without complex logic, - # we can check if the number of add calls is 1 (only for the document update) instead of 2 (document + binding) - - # Expected calls: - # 1. db.session.add(document) - # 2. NO db.session.add(binding) because it exists - - # Note: In the code, db.session.add is called for document. - # Then loop over metadata_list. - # If existing_binding found, continue. - # So binding add should be skipped. - - # Let's filter the calls to add to see what was added - add_calls = mock_db.session.add.call_args_list - added_objects = [call.args[0] for call in add_calls] - - # Check that no DatasetMetadataBinding was added - from models.dataset import DatasetMetadataBinding - - has_binding_add = any( - isinstance(obj, DatasetMetadataBinding) - or (isinstance(obj, MagicMock) and getattr(obj, "__class__", None) == DatasetMetadataBinding) - for obj in added_objects - ) - - # Since we mock everything, checking isinstance might be tricky if DatasetMetadataBinding - # is not the exact class used in the service (imports match). - # But we can check the count. - # If it were added, there would be 2 calls. If skipped, 1 call. - assert mock_db.session.add.call_count == 1 - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_rollback_called_on_commit_failure(self, mock_redis, mock_current_account, mock_document_service, mock_db): - """When db.session.commit() raises, rollback must be called and the exception must propagate.""" - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Make commit raise an exception - mock_db.session.commit.side_effect = RuntimeError("database connection lost") - - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="meta_id", name="key", value="value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act & Assert: the exception must propagate - with pytest.raises(RuntimeError, match="database connection lost"): - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify rollback was called - mock_db.session.rollback.assert_called_once() - - # Verify the lock key was cleaned up despite the failure - mock_redis.delete.assert_called_with("document_metadata_lock_doc_id") - - -if __name__ == "__main__": - unittest.main() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 49e572584b..1e898ada11 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -9,9 +9,9 @@ import pytest from pytest_mock import MockerFixture from constants import HIDDEN_VALUE -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( CredentialFormSchema, FieldModelSchema, FormType, @@ -69,9 +69,13 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: def service(mocker: MockerFixture) -> ModelLoadBalancingService: # Arrange provider_manager = MagicMock() - mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager) + model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock()) + mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly) svc = ModelLoadBalancingService() svc.provider_manager = provider_manager + svc.model_assembly = model_assembly + svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign] return svc @@ -666,6 +670,9 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_ assert mock_validate.call_count == 2 assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + shared_model_provider_factory = service.model_assembly.model_provider_factory + assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory + assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( @@ -708,7 +715,6 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") mock_factory = MagicMock() mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} - mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mock_encrypt = mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -722,6 +728,7 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val model="gpt-4o-mini", credentials={"api_key": "plain"}, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=mock_factory, validate=True, ) @@ -740,7 +747,6 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) mock_factory = MagicMock() mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} - mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -753,6 +759,7 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m model_type=ModelType.LLM, model="gpt-4o-mini", credentials={"api_key": "plain"}, + model_provider_factory=mock_factory, validate=True, ) diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 6a6b63f003..97f3bd6f01 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -3,9 +3,9 @@ import types import pytest from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod from models.provider import ProviderType from services.model_provider_service import ModelProviderService @@ -71,7 +71,7 @@ def service_with_fake_configurations(): return _FakeConfigurations(fake_provider_configuration) svc = ModelProviderService() - svc.provider_manager = _FakeProviderManager() + svc._get_provider_manager = lambda tenant_id: _FakeProviderManager() return svc diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py deleted file mode 100644 index 231ceb74dc..0000000000 --- a/api/tests/unit_tests/services/test_oauth_server_service.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import BadRequest - -from services.oauth_server import ( - OAUTH_ACCESS_TOKEN_EXPIRES_IN, - OAUTH_ACCESS_TOKEN_REDIS_KEY, - OAUTH_AUTHORIZATION_CODE_REDIS_KEY, - OAUTH_REFRESH_TOKEN_EXPIRES_IN, - OAUTH_REFRESH_TOKEN_REDIS_KEY, - OAuthGrantType, - OAuthServerService, -) - - -@pytest.fixture -def mock_redis_client(mocker: MockerFixture) -> MagicMock: - return mocker.patch("services.oauth_server.redis_client") - - -@pytest.fixture -def mock_session(mocker: MockerFixture) -> MagicMock: - """Mock the OAuth server Session context manager.""" - mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - mocker.patch("services.oauth_server.Session", return_value=session_cm) - return session - - -def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: - # Arrange - mock_execute_result = MagicMock() - expected_app = MagicMock() - mock_execute_result.scalar_one_or_none.return_value = expected_app - mock_session.execute.return_value = mock_execute_result - - # Act - result = OAuthServerService.get_oauth_provider_app("client-1") - - # Assert - assert result is expected_app - mock_session.execute.assert_called_once() - mock_execute_result.scalar_one_or_none.assert_called_once() - - -def test_sign_oauth_authorization_code_should_store_code_and_return_value( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") - - # Assert - expected_code = str(deterministic_uuid) - assert code == expected_code - mock_redis_client.set.assert_called_once_with( - OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), - "user-1", - ex=600, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid code"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="bad-code", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - token_uuids = [ - uuid.UUID("00000000-0000-0000-0000-000000000201"), - uuid.UUID("00000000-0000-0000-0000-000000000202"), - ] - mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) - mock_redis_client.get.return_value = b"user-1" - code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") - - # Act - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="code-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(token_uuids[0]) - assert refresh_token == str(token_uuids[1]) - mock_redis_client.delete.assert_called_once_with(code_key) - mock_redis_client.set.assert_any_call( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - mock_redis_client.set.assert_any_call( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), - b"user-1", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid refresh token"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="stale-token", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - mock_redis_client.get.return_value = b"user-1" - - # Act - access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="refresh-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(deterministic_uuid) - assert returned_refresh_token == "refresh-1" - mock_redis_client.set.assert_called_once_with( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: - # Arrange - grant_type = cast(OAuthGrantType, "invalid-grant-type") - - # Act - result = OAuthServerService.sign_oauth_access_token( - grant_type=grant_type, - client_id="client-1", - ) - - # Assert - assert result is None - - -def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") - - # Assert - assert refresh_token == str(deterministic_uuid) - mock_redis_client.set.assert_called_once_with( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), - "user-2", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_validate_oauth_access_token_should_return_none_when_token_not_found( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") - - # Assert - assert result is None - - -def test_validate_oauth_access_token_should_load_user_when_token_exists( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - mock_redis_client.get.return_value = b"user-88" - expected_user = MagicMock() - mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") - - # Assert - assert result is expected_user - mock_load_user.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py deleted file mode 100644 index a214ecf728..0000000000 --- a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Unit tests for workflow run restore functionality. -""" - -from datetime import datetime - - -class TestWorkflowRunRestore: - """Tests for the WorkflowRunRestore class.""" - - def test_restore_initialization(self): - """Restore service should respect dry_run flag.""" - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - restore = WorkflowRunRestore(dry_run=True) - - assert restore.dry_run is True - - def test_convert_datetime_fields(self): - """ISO datetime strings should be converted to datetime objects.""" - from models.workflow import WorkflowRun - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - record = { - "id": "test-id", - "created_at": "2024-01-01T12:00:00", - "finished_at": "2024-01-01T12:05:00", - "name": "test", - } - - restore = WorkflowRunRestore() - result = restore._convert_datetime_fields(record, WorkflowRun) - - assert isinstance(result["created_at"], datetime) - assert result["created_at"].year == 2024 - assert result["created_at"].month == 1 - assert result["name"] == "test" diff --git a/api/tests/unit_tests/services/test_saved_message_service.py b/api/tests/unit_tests/services/test_saved_message_service.py deleted file mode 100644 index 87b946fe46..0000000000 --- a/api/tests/unit_tests/services/test_saved_message_service.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -Comprehensive unit tests for SavedMessageService. - -This test suite provides complete coverage of saved message operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Pagination (TestSavedMessageServicePagination) -Tests saved message listing and pagination: -- Pagination with valid user (Account and EndUser) -- Pagination without user raises ValueError -- Pagination with last_id parameter -- Empty results when no saved messages exist -- Integration with MessageService pagination - -### 2. Save Operations (TestSavedMessageServiceSave) -Tests saving messages: -- Save message for Account user -- Save message for EndUser -- Save without user (no-op) -- Prevent duplicate saves (idempotent) -- Message validation through MessageService - -### 3. Delete Operations (TestSavedMessageServiceDelete) -Tests deleting saved messages: -- Delete saved message for Account user -- Delete saved message for EndUser -- Delete without user (no-op) -- Delete non-existent saved message (no-op) -- Proper database cleanup - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked - for fast, isolated unit tests -- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**User Types:** -- Account: Workspace members (console users) -- EndUser: API users (end users) - -**Saved Messages:** -- Users can save messages for later reference -- Each user has their own saved message list -- Saving is idempotent (duplicate saves ignored) -- Deletion is safe (non-existent deletes ignored) -""" - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest - -from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models import Account -from models.model import App, EndUser, Message -from models.web import SavedMessage -from services.saved_message_service import SavedMessageService - - -class SavedMessageServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - saved message operations. - """ - - @staticmethod - def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock: - """ - Create a mock Account object. - - Args: - account_id: Unique identifier for the account - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Account object with specified attributes - """ - account = create_autospec(Account, instance=True) - account.id = account_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock: - """ - Create a mock EndUser object. - - Args: - user_id: Unique identifier for the end user - **kwargs: Additional attributes to set on the mock - - Returns: - Mock EndUser object with specified attributes - """ - user = create_autospec(EndUser, instance=True) - user.id = user_id - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant/workspace identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - """ - app = create_autospec(App, instance=True) - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - app.mode = kwargs.get("mode", "chat") - for key, value in kwargs.items(): - setattr(app, key, value) - return app - - @staticmethod - def create_message_mock( - message_id: str = "msg-123", - app_id: str = "app-123", - **kwargs, - ) -> Mock: - """ - Create a mock Message object. - - Args: - message_id: Unique identifier for the message - app_id: Associated app identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Message object with specified attributes - """ - message = create_autospec(Message, instance=True) - message.id = message_id - message.app_id = app_id - message.query = kwargs.get("query", "Test query") - message.answer = kwargs.get("answer", "Test answer") - message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(message, key, value) - return message - - @staticmethod - def create_saved_message_mock( - saved_message_id: str = "saved-123", - app_id: str = "app-123", - message_id: str = "msg-123", - created_by: str = "user-123", - created_by_role: str = "account", - **kwargs, - ) -> Mock: - """ - Create a mock SavedMessage object. - - Args: - saved_message_id: Unique identifier for the saved message - app_id: Associated app identifier - message_id: Associated message identifier - created_by: User who saved the message - created_by_role: Role of the user ('account' or 'end_user') - **kwargs: Additional attributes to set on the mock - - Returns: - Mock SavedMessage object with specified attributes - """ - saved_message = create_autospec(SavedMessage, instance=True) - saved_message.id = saved_message_id - saved_message.app_id = app_id - saved_message.message_id = message_id - saved_message.created_by = created_by - saved_message.created_by_role = created_by_role - saved_message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(saved_message, key, value) - return saved_message - - -@pytest.fixture -def factory(): - """Provide the test data factory to all tests.""" - return SavedMessageServiceTestDataFactory - - -class TestSavedMessageServicePagination: - """Test saved message pagination operations.""" - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Create saved messages for this user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="account", - ) - for i in range(3) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - mock_db_session.query.assert_called_once_with(SavedMessage) - # Verify MessageService was called with correct message IDs - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=["msg-0", "msg-1", "msg-2"], - ) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - - # Create saved messages for this end user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="end_user", - ) - for i in range(2) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10) - - # Assert - assert result == expected_pagination - # Verify correct role was used in query - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=10, - include_ids=["msg-0", "msg-1"], - ) - - def test_pagination_without_user_raises_error(self, factory): - """Test that pagination without user raises ValueError.""" - # Arrange - app = factory.create_app_mock() - - # Act & Assert - with pytest.raises(ValueError, match="User is required"): - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with last_id parameter.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - last_id = "msg-last" - - saved_messages = [ - factory.create_saved_message_mock( - message_id=f"msg-{i}", - app_id=app.id, - created_by=user.id, - ) - for i in range(5) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10) - - # Assert - assert result == expected_pagination - # Verify last_id was passed to MessageService - mock_message_pagination.assert_called_once() - call_args = mock_message_pagination.call_args - assert call_args.kwargs["last_id"] == last_id - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory): - """Test pagination when user has no saved messages.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Mock database query returning empty list - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - # Verify MessageService was called with empty include_ids - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=[], - ) - - -class TestSavedMessageServiceSave: - """Test save message operations.""" - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_account(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock(message_id="msg-123", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "account" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message = factory.create_message_mock(message_id="msg-456", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "end_user" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_without_user_does_nothing(self, mock_db_session, factory): - """Test that saving without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.save(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory): - """Test that saving an already saved message is idempotent.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-789" - - # Mock database query - existing saved message found - existing_saved = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_saved - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message_id) - - # Assert - no new saved message created - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - mock_get_message.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory): - """Test that save validates message exists through MessageService.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock() - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - MessageService.get_message was called for validation - mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id) - - -class TestSavedMessageServiceDelete: - """Test delete saved message operations.""" - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_account(self, mock_db_session, factory): - """Test deleting a saved message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-123" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_end_user(self, mock_db_session, factory): - """Test deleting a saved message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message_id = "msg-456" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="end_user", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_without_user_does_nothing(self, mock_db_session, factory): - """Test that deleting without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.delete(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory): - """Test that deleting a non-existent saved message is a no-op.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-nonexistent" - - # Mock database query - no saved message found - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - no deletion occurred - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory): - """Test that delete only removes the user's own saved message.""" - # Arrange - app = factory.create_app_mock() - user1 = factory.create_account_mock(account_id="user-1") - message_id = "msg-shared" - - # Mock database query - finds user1's saved message - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user1.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user1, message_id=message_id) - - # Assert - only user1's saved message is deleted - mock_db_session.delete.assert_called_once_with(saved_message) - # Verify the query filters by user - assert mock_query.where.called diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index be64e431ba..cbf3e121d8 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -26,7 +27,7 @@ class _SessionContext: return None -def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: +def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock: dataset = MagicMock(name="dataset") dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" @@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: if has_document: doc = MagicMock(name="document") doc.doc_language = "en" - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX segment.document = doc else: segment.document = None @@ -168,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: vector_cls = MagicMock() monkeypatch.setattr(summary_module, "Vector", vector_cls) - SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset) vector_cls.assert_not_called() @@ -189,7 +191,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: embedding_model.get_text_embedding_num_tokens.return_value = [5] model_manager = MagicMock() model_manager.get_model_instance.return_value = embedding_model - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) vector_instance = MagicMock() vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] @@ -228,7 +230,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat model_manager = MagicMock() model_manager.get_model_instance.side_effect = RuntimeError("no model") - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) # New session used after vectorization succeeds (record not found by id nor chunk_id). session = MagicMock(name="session") @@ -405,8 +407,8 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch vector_instance.add_texts.return_value = None monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -439,8 +441,8 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -472,8 +474,8 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -508,8 +510,8 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -620,16 +622,16 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _dataset(indexing_technique="economy") + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] dataset = _dataset() assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] @@ -637,7 +639,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg1 = _segment() seg2 = _segment() @@ -673,7 +675,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() query = MagicMock() @@ -696,7 +698,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg = _segment() session = MagicMock() @@ -777,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo def test_enable_summaries_for_segments_skips_non_high_quality() -> None: - SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY)) def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: @@ -931,11 +933,10 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon def test_update_summary_for_segment_skip_conditions() -> None: - assert ( - SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None - ) + economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None seg = _segment(has_document=True) - seg.document.doc_form = "qa_model" + seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 264eac4d77..b09463b1bc 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -75,6 +75,7 @@ import pytest from werkzeug.exceptions import NotFound from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TagServiceTestDataFactory: def create_tag_mock( tag_id: str = "tag-123", name: str = "Test Tag", - tag_type: str = "app", + tag_type: TagType = TagType.APP, tenant_id: str = "tenant-123", **kwargs, ) -> Mock: @@ -315,7 +316,7 @@ class TestTagServiceRetrieval: - get_tags_by_target_id: Get all tags bound to a specific target """ - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_binding_counts(self, mock_db_session, factory): """ Test retrieving tags with their binding counts. @@ -372,7 +373,7 @@ class TestTagServiceRetrieval: # Verify database query was called mock_db_session.query.assert_called_once() - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_with_keyword_filter(self, mock_db_session, factory): """ Test retrieving tags filtered by keyword (case-insensitive). @@ -426,7 +427,7 @@ class TestTagServiceRetrieval: # 2. Additional WHERE clause for keyword filtering assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): """ Test retrieving target IDs by tag IDs. @@ -482,7 +483,7 @@ class TestTagServiceRetrieval: # Verify both queries were executed assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): """ Test that empty tag_ids returns empty list. @@ -510,7 +511,7 @@ class TestTagServiceRetrieval: assert results == [], "Should return empty list for empty input" mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name(self, mock_db_session, factory): """ Test retrieving tags by name. @@ -552,7 +553,7 @@ class TestTagServiceRetrieval: assert len(results) == 1, "Should find exactly one tag" assert results[0].name == tag_name, "Tag name should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): """ Test that missing tag_type or tag_name returns empty list. @@ -580,7 +581,7 @@ class TestTagServiceRetrieval: # Verify no database queries were executed mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tags_by_target_id(self, mock_db_session, factory): """ Test retrieving tags associated with a specific target. @@ -653,7 +654,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") @patch("services.tag_service.uuid.uuid4", autospec=True) def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ @@ -705,7 +706,7 @@ class TestTagServiceCRUD: # Verify tag attributes added_tag = mock_db_session.add.call_args[0][0] assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == "app", "Tag type should match" + assert added_tag.type == TagType.APP, "Tag type should match" assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" @@ -742,7 +743,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ Test updating a tag name. @@ -794,7 +795,7 @@ class TestTagServiceCRUD: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_error_for_duplicate_name( self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory ): @@ -826,7 +827,7 @@ class TestTagServiceCRUD: with pytest.raises(ValueError, match="Tag name already exists"): TagService.update_tags(args, tag_id="tag-123") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): """ Test that updating a non-existent tag raises NotFound. @@ -858,7 +859,7 @@ class TestTagServiceCRUD: with pytest.raises(NotFound, match="Tag not found"): TagService.update_tags(args, tag_id="nonexistent") - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_get_tag_binding_count(self, mock_db_session, factory): """ Test getting the count of bindings for a tag. @@ -894,7 +895,7 @@ class TestTagServiceCRUD: # Verify count matches expectation assert result == expected_count, "Binding count should match" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag(self, mock_db_session, factory): """ Test deleting a tag and its bindings. @@ -950,7 +951,7 @@ class TestTagServiceCRUD: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_raises_not_found(self, mock_db_session, factory): """ Test that deleting a non-existent tag raises NotFound. @@ -998,7 +999,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test creating tag bindings. @@ -1049,7 +1050,7 @@ class TestTagServiceBindings: @patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test that saving duplicate bindings is idempotent. @@ -1089,7 +1090,7 @@ class TestTagServiceBindings: mock_db_session.add.assert_not_called(), "Should not create duplicate binding" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): """ Test deleting a tag binding. @@ -1137,7 +1138,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_called_once(), "Should commit transaction" @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): """ Test that deleting a non-existent binding is a no-op. @@ -1174,7 +1175,7 @@ class TestTagServiceBindings: mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): """ Test validating that a dataset target exists. @@ -1215,7 +1216,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for dataset" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): """ Test validating that an app target exists. @@ -1256,7 +1257,7 @@ class TestTagServiceBindings: mock_db_session.query.assert_called_once(), "Should query database for app" @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_dataset( self, mock_db_session, mock_current_user, factory ): @@ -1288,7 +1289,7 @@ class TestTagServiceBindings: TagService.check_target_exists("knowledge", "nonexistent") @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.db.session") def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): """ Test that missing app raises NotFound. diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index c703ab64d0..2fe6161785 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -17,9 +17,9 @@ from uuid import uuid4 import pytest -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segments import ( +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 7b0103a2a1..598ff3fc3a 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from services.vector_service import VectorService @@ -31,8 +32,8 @@ class _ParentDocStub: def _make_dataset( *, - indexing_technique: str = "high_quality", - doc_form: str = "text_model", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", is_multimodal: bool = False, @@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_called_once() args, kwargs = index_processor.load.call_args @@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) assert index_processor.load.call_count == 2 first_args, first_kwargs = index_processor.load.call_args_list[0] @@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_not_called() @@ -191,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider="openai", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -213,7 +214,9 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex embedding_model_instance = MagicMock(name="embedding_model_instance") model_manager_instance = MagicMock(name="model_manager_instance") model_manager_instance.get_model_instance.return_value = embedding_model_instance - monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) @@ -240,7 +243,7 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -261,7 +264,9 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p embedding_model_instance = MagicMock() model_manager_instance = MagicMock() model_manager_instance.get_default_model_instance.return_value = embedding_model_instance - monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) @@ -328,7 +333,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) segment = _make_segment() dataset_document = MagicMock() @@ -347,7 +352,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = _make_segment() vector_instance = MagicMock() @@ -363,7 +368,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -379,7 +384,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -392,7 +397,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1") segment = _make_segment(segment_id="seg-1") dataset_document = MagicMock() @@ -439,7 +444,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX) segment = _make_segment() dataset_document = MagicMock() dataset_document.doc_language = "en" @@ -472,7 +477,7 @@ def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = MagicMock() child_chunk.content = "child" child_chunk.index_node_id = "id" @@ -488,7 +493,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) @@ -504,7 +509,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = MagicMock() new_chunk.content = "n" @@ -535,7 +540,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) VectorService.update_child_chunk_vector([], [], [], dataset) @@ -560,7 +565,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) vector_cls = MagicMock() @@ -574,7 +579,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) vector_cls = MagicMock() @@ -590,7 +595,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) vector_instance = MagicMock(name="vector_instance") @@ -611,7 +616,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -629,7 +634,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -662,7 +667,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -682,7 +687,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() diff --git a/api/tests/unit_tests/services/test_web_conversation_service.py b/api/tests/unit_tests/services/test_web_conversation_service.py deleted file mode 100644 index 7687d355e9..0000000000 --- a/api/tests/unit_tests/services/test_web_conversation_service.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, EndUser -from services.web_conversation_service import WebConversationService - - -@pytest.fixture -def app_model() -> App: - return cast(App, SimpleNamespace(id="app-1")) - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -def _end_user(**kwargs: Any) -> EndUser: - return cast(EndUser, SimpleNamespace(**kwargs)) - - -def test_pagination_by_last_id_should_raise_error_when_user_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - - # Act + Assert - with pytest.raises(ValueError, match="User is required"): - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=None, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - -def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_user = _account(id="user-1") - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id="conv-9", - limit=10, - invoke_from=InvokeFrom.WEB_APP, - pinned=None, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] is None - assert call_kwargs["last_id"] == "conv-9" - assert call_kwargs["sort_by"] == "-updated_at" - - -def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {})) - session.scalars.return_value.all.return_value = ["conv-1", "conv-2"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=True, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] == ["conv-1", "conv-2"] - assert call_kwargs["exclude_ids"] is None - - -def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-1" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - session.scalars.return_value.all.return_value = ["conv-3"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=False, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] == ["conv-3"] - - -def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", None) - - # Assert - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_return_early_when_conversation_is_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = object() - mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", fake_user) - - # Assert - mock_get_conversation.assert_not_called() - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_create_pinned_conversation_when_not_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-2" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - mock_conversation = SimpleNamespace(id="conv-2") - mock_get_conversation = mocker.patch( - "services.web_conversation_service.ConversationService.get_conversation", - return_value=mock_conversation, - ) - - # Act - WebConversationService.pin(app_model, "conv-2", fake_user) - - # Assert - mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user) - added_obj = mock_db.session.add.call_args.args[0] - assert added_obj.app_id == "app-1" - assert added_obj.conversation_id == "conv-2" - assert added_obj.created_by_role == "account" - assert added_obj.created_by == "account-2" - mock_db.session.commit.assert_called_once() - - -def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - - # Act - WebConversationService.unpin(app_model, "conv-1", None) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_unpin_should_return_early_when_conversation_is_not_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-3" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act - WebConversationService.unpin(app_model, "conv-7", fake_user) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_unpin_should_delete_pinned_conversation_when_exists( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-4" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - pinned_obj = SimpleNamespace(id="pin-1") - mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj - - # Act - WebConversationService.unpin(app_model, "conv-8", fake_user) - - # Assert - mock_db.session.delete.assert_called_once_with(pinned_obj) - mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_webapp_auth_service.py b/api/tests/unit_tests/services/test_webapp_auth_service.py deleted file mode 100644 index 262c1f1524..0000000000 --- a/api/tests/unit_tests/services/test_webapp_auth_service.py +++ /dev/null @@ -1,379 +0,0 @@ -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import NotFound, Unauthorized - -from models import Account, AccountStatus -from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError -from services.webapp_auth_service import WebAppAuthService, WebAppAuthType - -ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback" -TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token" -TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data" - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.webapp_auth_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act + Assert - with pytest.raises(AccountNotFoundError): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(AccountLoginError, match="Account is banned"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -@pytest.mark.parametrize("password_value", [None, "hash"]) -def test_authenticate_should_raise_password_error_when_password_is_invalid( - password_value: str | None, - mocker: MockerFixture, -) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=False) - - # Act + Assert - with pytest.raises(AccountPasswordError, match="Invalid email or password"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=True) - - # Act - result = WebAppAuthService.authenticate("user@example.com", "pwd") - - # Assert - assert result is account - - -def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="u@example.com") - mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token") - - # Act - result = WebAppAuthService.login(account) - - # Assert - assert result == "jwt-token" - mock_get_token.assert_called_once_with(account=account) - - -def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act - result = WebAppAuthService.get_user_through_email("missing@example.com") - - # Assert - assert result is None - - -def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(Unauthorized, match="Account is banned"): - WebAppAuthService.get_user_through_email("user@example.com") - - -def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act - result = WebAppAuthService.get_user_through_email("user@example.com") - - # Assert - assert result is account - - -def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Email must be provided"): - WebAppAuthService.send_email_code_login_email(account=None, email=None) - - -def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account( - mocker: MockerFixture, -) -> None: - # Arrange - account = _account(email="user@example.com") - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6]) - mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") - - # Assert - assert result == "token-1" - mock_generate_token.assert_called_once() - assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"} - mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456") - - -def test_send_email_code_login_email_should_send_mail_for_email_without_account( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0]) - mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans") - - # Assert - assert result == "token-2" - mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000") - - -def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"}) - - # Act - result = WebAppAuthService.get_email_code_login_data("token-abc") - - # Assert - assert result == {"code": "123"} - mock_get_data.assert_called_once_with("token-abc", "email_code_login") - - -def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token") - - # Act - WebAppAuthService.revoke_email_code_login_token("token-xyz") - - # Assert - mock_revoke.assert_called_once_with("token-xyz", "email_code_login") - - -def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(NotFound, match="Site not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_query = MagicMock() - app_query.where.return_value.first.return_value = None - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None] - - # Act + Assert - with pytest.raises(NotFound, match="App not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model] - - # Act - result = WebAppAuthService.create_end_user("app-code", "user@example.com") - - # Assert - assert result.tenant_id == "tenant-1" - assert result.app_id == "app-1" - assert result.session_id == "user@example.com" - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="user@example.com") - mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60) - mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1") - - # Act - token = WebAppAuthService._get_account_jwt_token(account) - - # Assert - assert token == "jwt-1" - payload = mock_issue.call_args.args[0] - assert payload["user_id"] == "a1" - assert payload["session_id"] == "user@example.com" - assert payload["token_source"] == "webapp_login_token" - assert payload["auth_type"] == "internal" - assert payload["exp"] > int(datetime.now(UTC).timestamp()) - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("private", True), - ("private_all", True), - ("public", False), - ], -) -def test_is_app_require_permission_check_should_use_access_mode_when_provided( - access_mode: str, - expected: bool, -) -> None: - # Arrange - # Act - result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode) - - # Assert - assert result is expected - - -def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or app_id must be provided"): - WebAppAuthService.is_app_require_permission_check() - - -def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None) - - # Act + Assert - with pytest.raises(ValueError, match="App ID could not be determined"): - WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - -def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - # Assert - assert result is True - - -def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="public"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_id="app-1") - - # Assert - assert result is False - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("public", WebAppAuthType.PUBLIC), - ("private", WebAppAuthType.INTERNAL), - ("private_all", WebAppAuthType.INTERNAL), - ("sso_verified", WebAppAuthType.EXTERNAL), - ], -) -def test_get_app_auth_type_should_map_access_modes_correctly( - access_mode: str, - expected: WebAppAuthType, -) -> None: - # Arrange - # Act - result = WebAppAuthService.get_app_auth_type(access_mode=access_mode) - - # Assert - assert result == expected - - -def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private_all"), - ) - - # Act - result = WebAppAuthService.get_app_auth_type(app_code="app-code") - - # Assert - assert result == WebAppAuthType.INTERNAL - - -def test_get_app_auth_type_should_raise_when_no_input_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"): - WebAppAuthService.get_app_auth_type() - - -def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Could not determine app authentication type"): - WebAppAuthService.get_app_auth_type(access_mode="unknown") diff --git a/api/tests/unit_tests/services/test_workflow_app_service.py b/api/tests/unit_tests/services/test_workflow_app_service.py deleted file mode 100644 index fa76521f2d..0000000000 --- a/api/tests/unit_tests/services/test_workflow_app_service.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from dify_graph.enums import WorkflowExecutionStatus -from models import App, WorkflowAppLog -from models.enums import AppTriggerType, CreatorUserRole -from services.workflow_app_service import LogView, WorkflowAppService - - -@pytest.fixture -def service() -> WorkflowAppService: - # Arrange - return WorkflowAppService() - - -@pytest.fixture -def app_model() -> App: - # Arrange - return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) - - -def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog: - return cast(WorkflowAppLog, SimpleNamespace(**kwargs)) - - -def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None: - # Arrange - log = _workflow_app_log(id="log-1", status="succeeded") - view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) - - # Act - details = view.details - proxied_status = view.status - - # Assert - assert details == {"trigger_metadata": {"type": "plugin"}} - assert proxied_status == "succeeded" - - -def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - log_1 = SimpleNamespace(id="log-1") - log_2 = SimpleNamespace(id="log-2") - session.scalar.return_value = 3 - session.scalars.return_value.all.return_value = [log_1, log_2] - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - page=1, - limit=2, - detail=False, - ) - - # Assert - assert result["page"] == 1 - assert result["limit"] == 2 - assert result["total"] == 3 - assert result["has_more"] is True - assert len(result["data"]) == 2 - assert isinstance(result["data"][0], LogView) - assert result["data"][0].details is None - - -def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true( - service: WorkflowAppService, - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - session.scalar.side_effect = [1] - log_1 = SimpleNamespace(id="log-1") - session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')] - mock_handle = mocker.patch.object( - service, - "handle_trigger_metadata", - return_value={"type": "trigger_plugin", "icon": "url"}, - ) - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - keyword="run-1", - status=WorkflowExecutionStatus.SUCCEEDED, - created_at_before=None, - created_at_after=None, - page=1, - limit=20, - detail=True, - ) - - # Assert - assert result["total"] == 1 - assert len(result["data"]) == 1 - assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}} - mock_handle.assert_called_once() - - -def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - session.scalar.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Account not found: account@example.com"): - service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - created_by_account="account@example.com", - ) - - -def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0] - session.scalars.return_value.all.return_value = [] - - # Act - result = service.get_paginate_workflow_app_logs( - session=session, - app_model=app_model, - created_by_account="account@example.com", - ) - - # Assert - assert result["total"] == 0 - assert result["data"] == [] - - -def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items( - service: WorkflowAppService, - app_model: App, -) -> None: - # Arrange - session = MagicMock() - log_account = SimpleNamespace( - id="log-1", - created_by="acc-1", - created_by_role=CreatorUserRole.ACCOUNT, - workflow_run_summary={"run": "1"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-01", - ) - log_end_user = SimpleNamespace( - id="log-2", - created_by="end-1", - created_by_role=CreatorUserRole.END_USER, - workflow_run_summary={"run": "2"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-02", - ) - log_unknown = SimpleNamespace( - id="log-3", - created_by="other", - created_by_role="system", - workflow_run_summary={"run": "3"}, - trigger_metadata='{"type":"trigger-webhook"}', - log_created_at="2026-01-03", - ) - session.scalar.return_value = 3 - session.scalars.side_effect = [ - SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]), - SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]), - SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]), - ] - - # Act - result = service.get_paginate_workflow_archive_logs( - session=session, - app_model=app_model, - page=1, - limit=20, - ) - - # Assert - assert result["total"] == 3 - assert len(result["data"]) == 3 - assert result["data"][0]["created_by_account"].id == "acc-1" - assert result["data"][1]["created_by_end_user"].id == "end-1" - assert result["data"][2]["created_by_account"] is None - assert result["data"][2]["created_by_end_user"] is None - - -def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing( - service: WorkflowAppService, -) -> None: - # Arrange - # Act - result = service.handle_trigger_metadata("tenant-1", None) - - # Assert - assert result == {} - - -def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin( - service: WorkflowAppService, - mocker: MockerFixture, -) -> None: - # Arrange - meta = { - "type": AppTriggerType.TRIGGER_PLUGIN.value, - "icon_filename": "light.png", - "icon_dark_filename": "dark.png", - } - mock_icon = mocker.patch( - "services.workflow_app_service.PluginService.get_plugin_icon_url", - side_effect=["https://cdn/light.png", "https://cdn/dark.png"], - ) - - # Act - result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) - - # Assert - assert result["icon"] == "https://cdn/light.png" - assert result["icon_dark"] == "https://cdn/dark.png" - assert mock_icon.call_count == 2 - - -def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup( - service: WorkflowAppService, - mocker: MockerFixture, -) -> None: - # Arrange - meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} - mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url") - - # Act - result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) - - # Assert - assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value - mock_icon.assert_not_called() - - -@pytest.mark.parametrize( - ("value", "expected"), - [ - (None, None), - ("", None), - ('{"k":"v"}', {"k": "v"}), - ("not-json", None), - ({"raw": True}, {"raw": True}), - ], -) -def test_safe_json_loads_should_handle_various_inputs( - value: object, - expected: object, - service: WorkflowAppService, -) -> None: - # Arrange - # Act - result = service._safe_json_loads(value) - - # Assert - assert result == expected - - -def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None: - # Arrange - # Act - short_result = service._safe_parse_uuid("short") - invalid_result = service._safe_parse_uuid("x" * 40) - - # Assert - assert short_result is None - assert invalid_result is None - - -def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None: - # Arrange - raw_uuid = str(uuid.uuid4()) - - # Act - result = service._safe_parse_uuid(raw_uuid) - - # Assert - assert result is not None - assert str(result) == raw_uuid diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 27664c7e29..239cc83518 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -16,7 +16,7 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index d26c2f674f..da606c8329 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -12,22 +12,23 @@ This test suite covers: import json import uuid from typing import Any, cast -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import ( +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.variables.input_entities import VariableEntityType +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.node_events import NodeRunResult +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.variables.input_entities import VariableEntityType from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from models.model import App, AppMode @@ -1545,7 +1546,10 @@ class TestWorkflowServiceCredentialValidation: def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: """If ModelManager raises any exception it must be wrapped into ValueError.""" # Arrange - with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + assembly = MagicMock() + assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key") + + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4") @@ -1558,30 +1562,30 @@ class TestWorkflowServiceCredentialValidation: mock_configs = MagicMock() mock_configs.get_models.return_value = [mock_model] + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act service._validate_llm_model_config("tenant-1", "openai", "gpt-4") # Assert mock_model.raise_for_status.assert_called_once() + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4", + ) def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: """Test ValueError when model is not found in provider configurations.""" mock_configs = MagicMock() mock_configs.get_models.return_value = [] # No models + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4") @@ -2053,22 +2057,37 @@ class TestSetupVariablePool: workflow = self._make_workflow() # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, + ): _setup_variable_pool( query="hello", files=[], user_id="u-1", user_inputs={"k": "v"}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-1", conversation_variables=[], ) - # Assert — VariablePool should be called with a SystemVariable (non-default) - MockPool.assert_called_once() - call_kwargs = MockPool.call_args.kwargs - assert call_kwargs["user_inputs"] == {"k": "v"} + # Assert — start nodes should build bootstrap variables and attach node inputs. + MockPool.assert_called_once_with() + mock_build_system_variables.assert_called_once() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_called_once_with( + MockPool.return_value, + node_id="start-node", + inputs={"k": "v"}, + ) def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( self, @@ -2079,7 +2098,10 @@ class TestSetupVariablePool: # Act with ( patch("services.workflow_service.VariablePool") as MockPool, - patch("services.workflow_service.SystemVariable.default") as mock_default, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, ): _setup_variable_pool( query="", @@ -2087,14 +2109,20 @@ class TestSetupVariablePool: user_id="u-1", user_inputs={}, workflow=workflow, + node_id="llm-node", node_type=BuiltinNodeTypes.LLM, # not a start/trigger node conversation_id="conv-1", conversation_variables=[], ) - # Assert — SystemVariable.default() should be used for non-start nodes - mock_default.assert_called_once() - MockPool.assert_called_once() + # Assert — default system variables should be used and node inputs should not be added. + mock_default_system_variables.assert_called_once() + MockPool.assert_called_once_with() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_not_called() def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( self, @@ -2106,20 +2134,31 @@ class TestSetupVariablePool: workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables"), + patch("services.workflow_service.add_variables_to_pool"), + patch("services.workflow_service.add_node_inputs_to_pool"), + ): _setup_variable_pool( query="what is AI?", files=[], user_id="u-1", user_inputs={}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-abc", conversation_variables=[], ) - # Assert — we just verify VariablePool was called (chatflow path executed) - MockPool.assert_called_once() + # Assert — chatflow system variables should include query, conversation_id and dialogue_count. + MockPool.assert_called_once_with() + system_variable_values = mock_build_system_variables.call_args.args[0] + assert system_variable_values["query"] == "what is AI?" + assert system_variable_values["conversation_id"] == "conv-abc" + assert system_variable_values["dialogue_count"] == 1 class TestRebuildSingleFile: @@ -2142,7 +2181,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_file - mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_value_not_dict( self, @@ -2165,7 +2204,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_files - mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_list_value_not_list( self, @@ -2279,13 +2318,12 @@ class TestWorkflowServiceResolveDeliveryMethod: # Arrange method_a = self._make_method("method-1") method_b = self._make_method("method-2") - node_data = MagicMock() - node_data.delivery_methods = [method_a, method_b] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-2" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-2" + ) # Assert assert result is method_b @@ -2293,26 +2331,22 @@ class TestWorkflowServiceResolveDeliveryMethod: def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: # Arrange method_a = self._make_method("method-1") - node_data = MagicMock() - node_data.delivery_methods = [method_a] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="does-not-exist" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="does-not-exist" + ) # Assert assert result is None def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: - # Arrange - node_data = MagicMock() - node_data.delivery_methods = [] - # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-1" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-1" + ) # Assert assert result is None @@ -2435,6 +2469,9 @@ class TestWorkflowServiceDraftExecution: patch("services.workflow_service.Session"), patch("services.workflow_service.WorkflowDraftVariableService"), patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, patch("services.workflow_service.DraftVarLoader"), patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, patch("services.workflow_service.DifyCoreRepositoryFactory"), @@ -2475,10 +2512,16 @@ class TestWorkflowServiceDraftExecution: ) # Assert - # For non-start nodes, VariablePool should be initialized with environment_variables - mock_pool_cls.assert_called_once() - args, kwargs = mock_pool_cls.call_args - assert "environment_variables" in kwargs + # For non-start nodes, bootstrap variables should be loaded into an empty pool. + mock_pool_cls.assert_called_once_with() + mock_default_system_variables.assert_called_once() + mock_build_bootstrap_variables.assert_called_once_with( + system_variables=mock_default_system_variables.return_value, + environment_variables=draft_workflow.environment_variables, + ) + mock_add_variables_to_pool.assert_called_once_with( + mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value + ) # =========================================================================== @@ -2588,7 +2631,7 @@ class TestWorkflowServiceHumanInputOperations: patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), patch("services.workflow_service.HumanInputNodeData.model_validate"), patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, - patch("services.workflow_service.apply_debug_email_recipient"), + patch("services.workflow_service.apply_dify_debug_email_recipient"), patch.object(service, "_build_human_input_variable_pool"), patch.object(service, "_build_human_input_node"), patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), @@ -2708,7 +2751,7 @@ class TestWorkflowServiceFreeNodeExecution: def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: with patch( - "dify_graph.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + "graphon.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") ): with pytest.raises(ValueError, match="Invalid HumanInput node data"): service._validate_human_input_node_data({}) @@ -2730,13 +2773,15 @@ class TestWorkflowServiceFreeNodeExecution: variable_pool = MagicMock() with ( - patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphInitParams") as mock_graph_init_params, patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.build_dify_run_context"), + patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls, - patch("services.workflow_service.HumanInputFormRepositoryImpl"), ): node = service._build_human_input_node( workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool ) assert node == mock_node_cls.return_value mock_node_cls.assert_called_once() + mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context) diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py deleted file mode 100644 index ce44818886..0000000000 --- a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py +++ /dev/null @@ -1,643 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.tools.entities.tool_entities import ApiProviderSchemaType -from services.tools.api_tools_manage_service import ApiToolManageService - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: - return SimpleNamespace(operation_id=operation_id) - - -def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), - ) - - # Act - result = ApiToolManageService.parser_api_schema("valid-schema") - - # Assert - assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value - assert len(result["credentials_schema"]) == 3 - assert "warning" in result - - -def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("bad schema"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): - ApiToolManageService.parser_api_schema("invalid") - - -def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: - # Arrange - expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=expected, - ) - extra_info: dict[str, str] = {} - - # Act - result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) - - # Assert - assert result == expected - - -def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=ValueError("parse failed"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: parse failed"): - ApiToolManageService.convert_schema_to_tool_bundles("schema") - - -def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="provider provider-a already exists"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name=" provider-a ", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - many_tools = [_tool_bundle(str(i)) for i in range(101)] - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=(many_tools, ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="the number of apis should be less than 100"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_create_provider_when_input_is_valid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - mock_controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=mock_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.encrypt.return_value = {"auth_type": "none"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - mock_controller.load_bundled_tools.assert_called_once() - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=200, text="schema-content"), - ) - mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) - - # Act - result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - - # Assert - assert result == {"schema": "schema-content"} - - -@pytest.mark.parametrize("status_code", [400, 404, 500]) -def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( - status_code: int, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=status_code, text="schema-content"), - ) - mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): - ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - mock_logger.exception.assert_called_once() - - -def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - -def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - return_value=controller, - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], - ) - - # Act - result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - # Assert - assert result == [{"name": "tool-a"}, {"name": "tool-b"}] - assert mock_convert.call_count == 2 - - -def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="api provider provider-a does not exists"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={"auth_type": "none"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(credentials={}, name="old") - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace( - credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, - name="old", - icon="", - schema="", - description="", - schema_type_str="", - tools_str="", - privacy_policy="", - custom_disclaimer="", - credentials_str="", - ) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=controller, - ) - cache = MagicMock() - encrypter = MagicMock() - encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} - encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} - encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(encrypter, cache), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-new", - original_provider="provider-old", - icon={"emoji": "E"}, - credentials={"auth_type": "none", "api_key_value": "***"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - assert provider.name == "provider-new" - assert provider.privacy_policy == "privacy" - assert provider.credentials_str != "" - cache.delete.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - -def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: - # Arrange - provider = object() - mock_db.session.query.return_value.where.return_value.first.return_value = provider - - # Act - result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == {"result": "success"} - mock_db.session.delete.assert_called_once_with(provider) - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: - # Arrange - expected = {"provider": "value"} - mock_get = mocker.patch( - "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", - return_value=expected, - ) - - # Act - result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == expected - mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") - - -def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: - # Arrange - schema_type = "bad-schema-type" - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema type"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=schema_type, # type: ignore[arg-type] - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("invalid"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="invalid tool name tool-b"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-b", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.side_effect = ValueError("validation failed") - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"error": "validation failed"} - - -def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.return_value = {"ok": True} - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={"x": "1"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"result": {"ok": True}} - - -def test_list_api_tools_should_return_all_user_providers_with_converted_tools( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_one = SimpleNamespace(name="p1") - provider_two = SimpleNamespace(name="p2") - mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] - - controller_one = MagicMock() - controller_one.get_tools.return_value = ["tool-a"] - controller_two = MagicMock() - controller_two.get_tools.return_value = ["tool-b", "tool-c"] - - user_provider_one = SimpleNamespace(labels=[], tools=[]) - user_provider_two = SimpleNamespace(labels=[], tools=[]) - - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - side_effect=[controller_one, controller_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", - side_effect=[user_provider_one, user_provider_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], - ) - - # Act - result = ApiToolManageService.list_api_tools("tenant-1") - - # Assert - assert len(result) == 2 - assert user_provider_one.tools == [{"name": "tool-a"}] - assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] - assert mock_convert.call_count == 3 diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py deleted file mode 100644 index d35e014fab..0000000000 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py +++ /dev/null @@ -1,1045 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -from datetime import datetime -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from sqlalchemy.exc import IntegrityError - -from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity -from core.mcp.entities import AuthActionType -from core.mcp.error import MCPAuthError, MCPError -from models.tools import MCPToolProvider -from services.tools.mcp_tools_manage_service import ( - EMPTY_CREDENTIALS_JSON, - EMPTY_TOOLS_JSON, - UNCHANGED_SERVER_URL_PLACEHOLDER, - MCPToolManageService, - OAuthDataType, - ProviderUrlValidationData, - ReconnectResult, - ServerUrlValidationResult, -) - - -class _ToolStub: - def __init__(self, name: str, description: str | None) -> None: - self._name = name - self._description = description - - def model_dump(self) -> dict[str, str | None]: - return {"name": self._name, "description": self._description} - - -@pytest.fixture -def mock_session() -> MagicMock: - # Arrange - return MagicMock() - - -@pytest.fixture -def service(mock_session: MagicMock) -> MCPToolManageService: - # Arrange - return MCPToolManageService(session=mock_session) - - -def _provider_entity_stub(*, authed: bool = True) -> MCPProviderEntity: - return cast( - MCPProviderEntity, - SimpleNamespace( - authed=authed, - timeout=30.0, - sse_read_timeout=300.0, - provider_id="server-1", - headers={"x-api-key": "enc"}, - decrypt_headers=lambda: {"x-api-key": "key"}, - retrieve_tokens=lambda: SimpleNamespace(token_type="bearer", access_token="token-1"), - decrypt_server_url=lambda: "https://mcp.example.com/sse", - to_api_response=lambda user_name=None: { - "id": "provider-1", - "author": user_name or "Anonymous", - "name": "MCP Tool", - "description": {"en_US": "", "zh_Hans": ""}, - "icon": "icon", - "label": {"en_US": "MCP Tool", "zh_Hans": "MCP Tool"}, - "type": "mcp", - "is_team_authorization": True, - "server_url": "https://mcp.example.com/******", - "updated_at": 1, - "server_identifier": "server-1", - "configuration": {"timeout": "30", "sse_read_timeout": "300"}, - "masked_headers": {}, - "is_dynamic_registration": True, - }, - decrypt_credentials=lambda: {"client_id": "plain-id", "client_secret": "plain-secret"}, - masked_credentials=lambda: {"client_id": "pl***id", "client_secret": "pl***et"}, - masked_headers=lambda: {"x-api-key": "ke***ey"}, - ), - ) - - -def _provider_stub(*, authed: bool = True) -> MCPToolProvider: - entity = _provider_entity_stub(authed=authed) - return cast( - MCPToolProvider, - SimpleNamespace( - id="provider-1", - tenant_id="tenant-1", - user_id="user-1", - name="Provider A", - server_identifier="server-1", - server_url="encrypted-url", - server_url_hash="old-hash", - authed=authed, - tools=EMPTY_TOOLS_JSON, - encrypted_credentials=json.dumps({"existing": "credential"}), - encrypted_headers=json.dumps({"x-api-key": "enc"}), - credentials={"existing": "credential"}, - timeout=30.0, - sse_read_timeout=300.0, - updated_at=datetime.now(), - icon="icon", - to_entity=lambda: entity, - load_user=lambda: SimpleNamespace(name="Tester"), - ), - ) - - -def test_server_url_validation_result_should_update_server_url_when_all_conditions_match() -> None: - # Arrange - result = ServerUrlValidationResult( - needs_validation=True, - validation_passed=True, - reconnect_result=ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}"), - ) - - # Act - should_update = result.should_update_server_url - - # Assert - assert should_update is True - - -def test_get_provider_should_return_provider_when_exists( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - provider = _provider_stub() - mock_session.scalar.return_value = provider - - # Act - result = service.get_provider(provider_id="provider-1", tenant_id="tenant-1") - - # Assert - assert result is provider - - -def test_get_provider_should_raise_error_when_provider_not_found( - service: MCPToolManageService, mock_session: MagicMock -) -> None: - # Arrange - mock_session.scalar.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool not found"): - service.get_provider(provider_id="provider-404", tenant_id="tenant-1") - - -def test_get_provider_entity_should_get_entity_by_provider_id_when_by_server_id_is_false( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_entity("provider-1", "tenant-1", by_server_id=False) - - # Assert - assert result is provider.to_entity() - mock_get_provider.assert_called_once_with(provider_id="provider-1", tenant_id="tenant-1") - - -def test_get_provider_entity_should_get_entity_by_server_identifier_when_by_server_id_is_true( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_entity("server-1", "tenant-1", by_server_id=True) - - # Assert - assert result is provider.to_entity() - mock_get_provider.assert_called_once_with(server_identifier="server-1", tenant_id="tenant-1") - - -def test_create_provider_should_raise_error_when_server_url_is_invalid(service: MCPToolManageService) -> None: - # Arrange - config = MCPConfiguration(timeout=30, sse_read_timeout=300) - - # Act + Assert - with pytest.raises(ValueError, match="Server URL is not valid"): - service.create_provider( - tenant_id="tenant-1", - name="Provider A", - server_url="invalid-url", - user_id="user-1", - icon="icon", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=config, - ) - - -def test_create_provider_should_create_and_return_user_provider_when_input_is_valid( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - config = MCPConfiguration(timeout=42, sse_read_timeout=123) - auth_data = MCPAuthentication(client_id="client-id", client_secret="secret") - mocker.patch.object(service, "_check_provider_exists") - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="encrypted-url") - mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x":"enc"}') - mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') - mocker.patch.object(service, "_prepare_icon", return_value='{"content":"😀"}') - expected_user_provider = {"id": "provider-1"} - mock_convert = mocker.patch( - "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", - return_value=expected_user_provider, - ) - - # Act - result = service.create_provider( - tenant_id="tenant-1", - name="Provider A", - server_url="https://mcp.example.com", - user_id="user-1", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=config, - authentication=auth_data, - headers={"x-api-key": "v1"}, - ) - - # Assert - assert result == expected_user_provider - mock_session.add.assert_called_once() - mock_session.flush.assert_called_once() - mock_convert.assert_called_once() - - -def test_update_provider_should_raise_error_when_new_name_conflicts( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="New Name", - server_url="https://mcp.example.com", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=MCPConfiguration(), - ) - - -def test_update_provider_should_update_fields_when_input_is_valid( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - validation = ServerUrlValidationResult( - needs_validation=True, - validation_passed=True, - reconnect_result=ReconnectResult(authed=True, tools='[{"name":"t"}]', encrypted_credentials='{"x":"y"}'), - encrypted_server_url="new-encrypted-url", - server_url_hash="new-hash", - ) - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = None - mocker.patch.object(service, "_prepare_icon", return_value="new-icon") - mocker.patch.object(service, "_process_headers", return_value='{"x":"enc"}') - mocker.patch.object(service, "_process_credentials", return_value='{"client":"enc"}') - - # Act - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="Provider B", - server_url="https://mcp.example.com/new", - icon="😎", - icon_type="emoji", - icon_background="#000", - server_identifier="server-2", - headers={"x-api-key": "v2"}, - configuration=MCPConfiguration(timeout=50, sse_read_timeout=120), - authentication=MCPAuthentication(client_id="new-id", client_secret="new-secret"), - validation_result=validation, - ) - - # Assert - assert provider.name == "Provider B" - assert provider.server_identifier == "server-2" - assert provider.server_url == "new-encrypted-url" - assert provider.server_url_hash == "new-hash" - assert provider.authed is True - assert provider.tools == '[{"name":"t"}]' - assert provider.encrypted_credentials == '{"client":"enc"}' - assert provider.encrypted_headers == '{"x":"enc"}' - assert provider.timeout == 50 - assert provider.sse_read_timeout == 120 - mock_session.flush.assert_called_once() - - -def test_update_provider_should_handle_integrity_error_with_readable_message( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = None - mocker.patch.object(service, "_prepare_icon", return_value="icon") - mock_session.flush.side_effect = IntegrityError("stmt", {}, Exception("unique_mcp_provider_name")) - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool Provider A already exists"): - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="Provider A", - server_url="https://mcp.example.com", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=MCPConfiguration(), - ) - - -def test_delete_provider_should_delete_existing_provider( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - service.delete_provider(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - mock_session.delete.assert_called_once_with(provider) - - -def test_list_providers_should_return_empty_list_when_no_provider_exists( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - mock_session.scalars.return_value.all.return_value = [] - - # Act - result = service.list_providers(tenant_id="tenant-1") - - # Assert - assert result == [] - - -def test_list_providers_should_convert_all_providers_and_attach_user_names( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_1 = _provider_stub() - provider_2 = _provider_stub() - provider_2.user_id = "user-2" - mock_session.scalars.return_value.all.return_value = [provider_1, provider_2] - mock_session.query.return_value.where.return_value.all.return_value = [ - SimpleNamespace(id="user-1", name="Alice"), - SimpleNamespace(id="user-2", name="Bob"), - ] - mock_convert = mocker.patch( - "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", - side_effect=[{"id": "1"}, {"id": "2"}], - ) - - # Act - result = service.list_providers(tenant_id="tenant-1", for_list=True, include_sensitive=False) - - # Assert - assert result == [{"id": "1"}, {"id": "2"}] - assert mock_convert.call_count == 2 - - -def test_list_provider_tools_should_raise_error_when_provider_is_not_authenticated( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=False) - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act + Assert - with pytest.raises(ValueError, match="Please auth the tool first"): - service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - -def test_list_provider_tools_should_raise_error_when_remote_client_fails( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPError("connection failed") - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act + Assert - with pytest.raises(ValueError, match="Failed to connect to MCP server"): - service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - -def test_list_provider_tools_should_update_db_and_return_response_on_success( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [ - _ToolStub("tool-a", None), - _ToolStub("tool-b", "desc"), - ] - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) - - # Act - result = service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - assert result.plugin_unique_identifier == "server-1" - assert provider.authed is True - payload = json.loads(provider.tools) - assert payload[0]["description"] == "" - assert payload[1]["description"] == "desc" - mock_session.flush.assert_called_once() - - -def test_update_provider_credentials_should_update_encrypted_credentials_and_auth_state( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - provider.encrypted_credentials = json.dumps({"existing": "value"}) - mocker.patch.object(service, "get_provider", return_value=provider) - mock_controller = MagicMock() - mocker.patch("core.tools.mcp_tool.provider.MCPToolProviderController.from_db", return_value=mock_controller) - mock_encryptor = MagicMock() - mock_encryptor.encrypt.return_value = {"access_token": "encrypted-token"} - mocker.patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter", return_value=mock_encryptor) - - # Act - service.update_provider_credentials( - provider_id="provider-1", - tenant_id="tenant-1", - credentials={"access_token": "plain-token"}, - authed=False, - ) - - # Assert - assert provider.authed is False - assert provider.tools == EMPTY_TOOLS_JSON - assert json.loads(cast(str, provider.encrypted_credentials))["access_token"] == "encrypted-token" - mock_session.flush.assert_called_once() - - -@pytest.mark.parametrize( - ("data_type", "data", "expected_authed"), - [ - (OAuthDataType.TOKENS, {"access_token": "token"}, True), - (OAuthDataType.MIXED, {"access_token": "token"}, True), - (OAuthDataType.MIXED, {"client_id": "id"}, None), - (OAuthDataType.CLIENT_INFO, {"client_id": "id"}, None), - ], -) -def test_save_oauth_data_should_delegate_with_expected_authed_value( - data_type: OAuthDataType, - data: dict[str, str], - expected_authed: bool | None, - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mock_update = mocker.patch.object(service, "update_provider_credentials") - - # Act - service.save_oauth_data("provider-1", "tenant-1", data, data_type) - - # Assert - assert mock_update.call_args.kwargs["authed"] == expected_authed - - -def test_clear_provider_credentials_should_reset_provider_state( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - service.clear_provider_credentials(provider_id="provider-1", tenant_id="tenant-1") - - # Assert - assert provider.tools == EMPTY_TOOLS_JSON - assert provider.encrypted_credentials == EMPTY_CREDENTIALS_JSON - assert provider.authed is False - - -def test_check_provider_exists_should_raise_different_errors_for_conflicts( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - mock_session.scalar.return_value = SimpleNamespace( - name="name-a", - server_url_hash="hash-a", - server_identifier="server-a", - ) - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool name-a already exists"): - service._check_provider_exists("tenant-1", "name-a", "hash-b", "server-b") - with pytest.raises(ValueError, match="MCP tool with this server URL already exists"): - service._check_provider_exists("tenant-1", "name-b", "hash-a", "server-b") - with pytest.raises(ValueError, match="MCP tool server-a already exists"): - service._check_provider_exists("tenant-1", "name-b", "hash-b", "server-a") - - -def test_prepare_icon_should_return_json_for_emoji_and_raw_value_for_non_emoji(service: MCPToolManageService) -> None: - # Arrange - # Act - emoji_icon = service._prepare_icon("😀", "emoji", "#fff") - raw_icon = service._prepare_icon("https://icon.png", "file", "#000") - - # Assert - assert json.loads(emoji_icon)["content"] == "😀" - assert raw_icon == "https://icon.png" - - -def test_encrypt_dict_fields_should_encrypt_secret_fields(service: MCPToolManageService, mocker: MockerFixture) -> None: - # Arrange - mock_encryptor = MagicMock() - mock_encryptor.encrypt.return_value = {"Authorization": "enc-token"} - mocker.patch("core.tools.utils.encryption.create_provider_encrypter", return_value=(mock_encryptor, MagicMock())) - - # Act - result = service._encrypt_dict_fields({"Authorization": "token"}, ["Authorization"], "tenant-1") - - # Assert - assert result == {"Authorization": "enc-token"} - - -def test_prepare_encrypted_dict_should_return_json_string(service: MCPToolManageService, mocker: MockerFixture) -> None: - # Arrange - mocker.patch.object(service, "_encrypt_dict_fields", return_value={"x": "enc"}) - - # Act - result = service._prepare_encrypted_dict({"x": "v"}, "tenant-1") - - # Assert - assert result == '{"x": "enc"}' - - -def test_prepare_auth_headers_should_append_authorization_when_tokens_exist(service: MCPToolManageService) -> None: - # Arrange - provider_entity = _provider_entity_stub() - - # Act - headers = service._prepare_auth_headers(provider_entity) - - # Assert - assert headers["Authorization"] == "Bearer token-1" - - -def test_retrieve_remote_mcp_tools_should_return_tools_from_client( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", "desc")] - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - tools = service._retrieve_remote_mcp_tools("https://mcp.example.com", {}, _provider_entity_stub()) - - # Assert - assert len(tools) == 1 - assert tools[0].model_dump()["name"] == "tool-a" - - -def test_execute_auth_actions_should_dispatch_supported_actions( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mock_save = mocker.patch.object(service, "save_oauth_data") - auth_result = SimpleNamespace( - actions=[ - SimpleNamespace( - action_type=AuthActionType.SAVE_CLIENT_INFO, - data={"client_id": "c1"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_TOKENS, - data={"access_token": "t1"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_CODE_VERIFIER, - data={"code_verifier": "cv"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_TOKENS, - data={"access_token": "skip"}, - provider_id=None, - tenant_id="tenant-1", - ), - ], - response={"ok": "1"}, - ) - - # Act - result = service.execute_auth_actions(auth_result) - - # Assert - assert result == {"ok": "1"} - assert mock_save.call_count == 3 - - -def test_auth_with_actions_should_call_auth_and_execute_actions( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider_entity = _provider_entity_stub() - auth_result = SimpleNamespace(actions=[], response={"status": "ok"}) - mocker.patch("services.tools.mcp_tools_manage_service.auth", return_value=auth_result) - mock_execute = mocker.patch.object(service, "execute_auth_actions", return_value={"status": "ok"}) - - # Act - result = service.auth_with_actions(provider_entity=provider_entity, authorization_code="code-1") - - # Assert - assert result == {"status": "ok"} - mock_execute.assert_called_once_with(auth_result) - - -def test_get_provider_for_url_validation_should_return_validation_data( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_for_url_validation(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - assert result.current_server_url_hash == "old-hash" - assert result.headers == {"x-api-key": "enc"} - - -def test_validate_server_url_standalone_should_skip_validation_for_unchanged_placeholder() -> None: - # Arrange - data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=UNCHANGED_SERVER_URL_PLACEHOLDER, - validation_data=data, - ) - - # Assert - assert result.needs_validation is False - - -def test_validate_server_url_standalone_should_raise_error_for_invalid_url() -> None: - # Arrange - data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) - - # Act + Assert - with pytest.raises(ValueError, match="Server URL is not valid"): - MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url="bad-url", - validation_data=data, - ) - - -def test_validate_server_url_standalone_should_return_no_validation_when_hash_unchanged(mocker: MockerFixture) -> None: - # Arrange - url = "https://mcp.example.com" - current_hash = hashlib.sha256(url.encode()).hexdigest() - data = ProviderUrlValidationData(current_server_url_hash=current_hash, headers={}, timeout=30, sse_read_timeout=300) - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-url") - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=url, - validation_data=data, - ) - - # Assert - assert result.needs_validation is False - assert result.encrypted_server_url == "enc-url" - assert result.server_url_hash == current_hash - - -def test_validate_server_url_standalone_should_reconnect_when_url_changes(mocker: MockerFixture) -> None: - # Arrange - url = "https://mcp-new.example.com" - data = ProviderUrlValidationData(current_server_url_hash="old", headers={}, timeout=30, sse_read_timeout=300) - reconnect_result = ReconnectResult(authed=True, tools='[{"name":"x"}]', encrypted_credentials="{}") - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-new") - mock_reconnect = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=reconnect_result) - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=url, - validation_data=data, - ) - - # Assert - assert result.validation_passed is True - assert result.reconnect_result == reconnect_result - mock_reconnect.assert_called_once() - - -def test_reconnect_with_url_should_delegate_to_private_method(mocker: MockerFixture) -> None: - # Arrange - expected = ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}") - mock_delegate = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=expected) - - # Act - result = MCPToolManageService.reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result == expected - mock_delegate.assert_called_once() - - -def test_private_reconnect_with_url_should_return_authed_true_when_connection_succeeds(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", None)] - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - result = MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result.authed is True - assert json.loads(result.tools)[0]["description"] == "" - - -def test_private_reconnect_with_url_should_return_authed_false_on_auth_error(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPAuthError("auth required") - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - result = MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result.authed is False - assert result.tools == EMPTY_TOOLS_JSON - - -def test_private_reconnect_with_url_should_raise_value_error_on_mcp_error(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPError("network failure") - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act + Assert - with pytest.raises(ValueError, match="Failed to re-connect MCP server: network failure"): - MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - -def test_build_tool_provider_response_should_build_api_entity_with_tools( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = _provider_stub() - provider_entity = _provider_entity_stub() - tools = [_ToolStub("tool-a", "desc")] - mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) - - # Act - result = service._build_tool_provider_response(db_provider, provider_entity, tools) - - # Assert - assert result.plugin_unique_identifier == "server-1" - assert result.name == "MCP Tool" - - -@pytest.mark.parametrize( - ("orig_message", "expected_error"), - [ - ("unique_mcp_provider_name", "MCP tool name already exists"), - ("unique_mcp_provider_server_url", "MCP tool https://mcp.example.com already exists"), - ("unique_mcp_provider_server_identifier", "MCP tool server-1 already exists"), - ], -) -def test_handle_integrity_error_should_raise_readable_value_errors( - orig_message: str, - expected_error: str, - service: MCPToolManageService, -) -> None: - """Test that known integrity errors raise readable value errors.""" - # Arrange - error = IntegrityError("stmt", {}, Exception(orig_message)) - - # Act + Assert - with pytest.raises(ValueError, match=expected_error): - service._handle_integrity_error(error, "name", "https://mcp.example.com", "server-1") - - -def test_handle_integrity_error_should_reraise_unknown_error(service: MCPToolManageService) -> None: - """Test that unknown integrity errors are re-raised.""" - # Arrange - error = IntegrityError("stmt", {}, Exception("unknown-constraint")) - - # Act + Assert - with pytest.raises(IntegrityError) as exc_info: - service._handle_integrity_error(error, "name", "url", "identifier") - - assert exc_info.value is error - - -@pytest.mark.parametrize( - ("url", "expected"), - [ - ("https://mcp.example.com", True), - ("http://mcp.example.com", True), - ("", False), - ("invalid", False), - ("ftp://mcp.example.com", False), - ], -) -def test_is_valid_url_should_validate_supported_schemes( - url: str, - expected: bool, - service: MCPToolManageService, -) -> None: - # Arrange - # Act - result = service._is_valid_url(url) - - # Assert - assert result is expected - - -def test_update_optional_fields_should_update_only_non_none_values(service: MCPToolManageService) -> None: - # Arrange - provider = _provider_stub() - configuration = MCPConfiguration(timeout=99, sse_read_timeout=300) - - # Act - service._update_optional_fields(provider, configuration) - - # Assert - assert provider.timeout == 99 - assert provider.sse_read_timeout == 300 - - -def test_process_headers_should_return_none_when_empty_headers(service: MCPToolManageService) -> None: - # Arrange - provider = _provider_stub() - - # Act - result = service._process_headers({}, provider, "tenant-1") - - # Assert - assert result is None - - -def test_process_headers_should_merge_and_encrypt_headers( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "_merge_headers_with_masked", return_value={"x-api-key": "plain"}) - mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x-api-key":"enc"}') - - # Act - result = service._process_headers({"x-api-key": "*****"}, provider, "tenant-1") - - # Assert - assert result == '{"x-api-key":"enc"}' - - -def test_process_credentials_should_merge_and_encrypt_credentials( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - authentication = MCPAuthentication(client_id="masked-id", client_secret="masked-secret") - mocker.patch.object(service, "_merge_credentials_with_masked", return_value=("plain-id", "plain-secret")) - mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') - - # Act - result = service._process_credentials(authentication, provider, "tenant-1") - - # Assert - assert result == '{"client_information":{}}' - - -def test_merge_headers_with_masked_should_preserve_original_values_for_unchanged_masked_inputs( - service: MCPToolManageService, -) -> None: - # Arrange - provider = _provider_stub() - incoming_headers = {"x-api-key": "ke***ey", "new-header": "new-value", "dropped": "*****"} - - # Act - result = service._merge_headers_with_masked(incoming_headers, provider) - - # Assert - assert result["x-api-key"] == "key" - assert result["new-header"] == "new-value" - assert result["dropped"] == "*****" - - -def test_merge_credentials_with_masked_should_preserve_decrypted_values_when_masked_match( - service: MCPToolManageService, -) -> None: - # Arrange - provider = _provider_stub() - - # Act - client_id, client_secret = service._merge_credentials_with_masked("pl***id", "pl***et", provider) - - # Assert - assert client_id == "plain-id" - assert client_secret == "plain-secret" - - -def test_build_and_encrypt_credentials_should_encrypt_secret_when_client_secret_present( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch.object( - service, - "_encrypt_dict_fields", - return_value={ - "client_id": "id", - "client_name": "Dify", - "is_dynamic_registration": False, - "encrypted_client_secret": "enc-secret", - }, - ) - - # Act - result = service._build_and_encrypt_credentials("id", "secret", "tenant-1") - - # Assert - payload = json.loads(result) - assert payload["client_information"]["encrypted_client_secret"] == "enc-secret" - - -def test_build_and_encrypt_credentials_should_skip_secret_field_when_client_secret_is_none( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch.object( - service, - "_encrypt_dict_fields", - return_value={"client_id": "id", "client_name": "Dify", "is_dynamic_registration": False}, - ) - - # Act - result = service._build_and_encrypt_credentials("id", None, "tenant-1") - - # Assert - payload = json.loads(result) - assert "encrypted_client_secret" not in payload["client_information"] diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py deleted file mode 100644 index 9616d2f102..0000000000 --- a/api/tests/unit_tests/services/tools/test_tools_transform_service.py +++ /dev/null @@ -1,452 +0,0 @@ -from unittest.mock import Mock - -from core.tools.__base.tool import Tool -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType -from services.tools.tools_transform_service import ToolTransformService - - -class TestToolTransformService: - """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method""" - - def test_convert_tool_with_parameter_override(self): - """Test that runtime parameters correctly override base parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - # Create mock runtime parameters that override base parameters - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" # Different label to verify override - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.author == "test_author" - assert result.name == "test_tool" - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Find the overridden parameter - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" # Should be runtime version - - # Find the non-overridden parameter - original_param = next((p for p in result.parameters if p.name == "param2"), None) - assert original_param is not None - assert original_param.label == "Base Param 2" # Should be base version - - def test_convert_tool_with_additional_runtime_parameters(self): - """Test that additional runtime parameters are added to the final list""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters - one that overrides and one that's new - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "runtime_only" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Only Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 2 - - # Check that both parameters are present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "runtime_only" in param_names - - # Verify the overridden parameter has runtime version - overridden_param = next((p for p in result.parameters if p.name == "param1"), None) - assert overridden_param is not None - assert overridden_param.label == "Runtime Param 1" - - # Verify the new runtime parameter is included - new_param = next((p for p in result.parameters if p.name == "runtime_only"), None) - assert new_param is not None - assert new_param.label == "Runtime Only Param" - - def test_convert_tool_with_non_form_runtime_parameters(self): - """Test that non-FORM runtime parameters are not added as new parameters""" - # Create mock base parameters - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - # Create mock runtime parameters with different forms - runtime_param1 = Mock(spec=ToolParameter) - runtime_param1.name = "param1" - runtime_param1.form = ToolParameter.ToolParameterForm.FORM - runtime_param1.type = "string" - runtime_param1.label = "Runtime Param 1" - - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "llm_param" - runtime_param2.form = ToolParameter.ToolParameterForm.LLM - runtime_param2.type = "string" - runtime_param2.label = "LLM Param" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 1 # Only the FORM parameter should be present - - # Check that only the FORM parameter is present - param_names = [p.name for p in result.parameters] - assert "param1" in param_names - assert "llm_param" not in param_names - - def test_convert_tool_with_empty_parameters(self): - """Test conversion with empty base and runtime parameters""" - # Create mock tool with no parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_with_none_parameters(self): - """Test conversion when base parameters is None""" - # Create mock tool with None parameters - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = None - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 0 - - def test_convert_tool_parameter_order_preserved(self): - """Test that parameter order is preserved correctly""" - # Create mock base parameters in specific order - base_param1 = Mock(spec=ToolParameter) - base_param1.name = "param1" - base_param1.form = ToolParameter.ToolParameterForm.FORM - base_param1.type = "string" - base_param1.label = "Base Param 1" - - base_param2 = Mock(spec=ToolParameter) - base_param2.name = "param2" - base_param2.form = ToolParameter.ToolParameterForm.FORM - base_param2.type = "string" - base_param2.label = "Base Param 2" - - base_param3 = Mock(spec=ToolParameter) - base_param3.name = "param3" - base_param3.form = ToolParameter.ToolParameterForm.FORM - base_param3.type = "string" - base_param3.label = "Base Param 3" - - # Create runtime parameter that overrides middle parameter - runtime_param2 = Mock(spec=ToolParameter) - runtime_param2.name = "param2" - runtime_param2.form = ToolParameter.ToolParameterForm.FORM - runtime_param2.type = "string" - runtime_param2.label = "Runtime Param 2" - - # Create new runtime parameter - runtime_param4 = Mock(spec=ToolParameter) - runtime_param4.name = "param4" - runtime_param4.form = ToolParameter.ToolParameterForm.FORM - runtime_param4.type = "string" - runtime_param4.label = "Runtime Param 4" - - # Create mock tool - mock_tool = Mock(spec=Tool) - mock_tool.entity = Mock() - mock_tool.entity.parameters = [base_param1, base_param2, base_param3] - mock_tool.entity.identity = Mock() - mock_tool.entity.identity.author = "test_author" - mock_tool.entity.identity.name = "test_tool" - mock_tool.entity.identity.label = I18nObject(en_US="Test Tool") - mock_tool.entity.description = Mock() - mock_tool.entity.description.human = I18nObject(en_US="Test description") - mock_tool.entity.output_schema = {} - mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4] - - # Mock fork_tool_runtime to return the same tool - mock_tool.fork_tool_runtime.return_value = mock_tool - - # Call the method - result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None) - - # Verify the result - assert isinstance(result, ToolApiEntity) - assert result.parameters is not None - assert len(result.parameters) == 4 - - # Check that order is maintained: base parameters first, then new runtime parameters - param_names = [p.name for p in result.parameters] - assert param_names == ["param1", "param2", "param3", "param4"] - - # Verify that param2 was overridden with runtime version - param2 = result.parameters[1] - assert param2.name == "param2" - assert param2.label == "Runtime Param 2" - - -class TestWorkflowProviderToUserProvider: - """Test cases for ToolTransformService.workflow_provider_to_user_provider method""" - - def test_workflow_provider_to_user_provider_with_workflow_app_id(self): - """Test that workflow_provider_to_user_provider correctly sets workflow_app_id.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - workflow_app_id = "app_123" - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1", "label2"], - workflow_app_id=workflow_app_id, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "test_author" - assert result.name == "test_workflow_tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["label1", "label2"] - assert result.is_team_authorization is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] - - def test_workflow_provider_to_user_provider_without_workflow_app_id(self): - """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method without workflow_app_id - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["label1"], - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == ["label1"] - - def test_workflow_provider_to_user_provider_workflow_app_id_none(self): - """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller - provider_id = "provider_123" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "test_author" - mock_controller.entity.identity.name = "test_workflow_tool" - mock_controller.entity.identity.description = I18nObject(en_US="Test description") - mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.icon_dark = None - mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool") - - # Call the method with explicit None values - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=None, - workflow_app_id=None, - ) - - # Verify the result - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.workflow_app_id is None - assert result.labels == [] - - def test_workflow_provider_to_user_provider_preserves_other_fields(self): - """Test that workflow_provider_to_user_provider preserves all other entity fields.""" - from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - - # Create mock workflow tool provider controller with various fields - workflow_app_id = "app_456" - provider_id = "provider_456" - mock_controller = Mock(spec=WorkflowToolProviderController) - mock_controller.provider_id = provider_id - mock_controller.entity = Mock() - mock_controller.entity.identity = Mock() - mock_controller.entity.identity.author = "another_author" - mock_controller.entity.identity.name = "another_workflow_tool" - mock_controller.entity.identity.description = I18nObject( - en_US="Another description", zh_Hans="Another description" - ) - mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"} - mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"} - mock_controller.entity.identity.label = I18nObject( - en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool" - ) - - # Call the method - result = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=mock_controller, - labels=["automation", "workflow"], - workflow_app_id=workflow_app_id, - ) - - # Verify all fields are preserved correctly - assert isinstance(result, ToolProviderApiEntity) - assert result.id == provider_id - assert result.author == "another_author" - assert result.name == "another_workflow_tool" - assert result.description.en_US == "Another description" - assert result.description.zh_Hans == "Another description" - assert result.icon == {"type": "emoji", "content": "⚙️"} - assert result.icon_dark == {"type": "emoji", "content": "🔧"} - assert result.label.en_US == "Another Workflow Tool" - assert result.label.zh_Hans == "Another Workflow Tool" - assert result.type == ToolProviderType.WORKFLOW - assert result.workflow_app_id == workflow_app_id - assert result.labels == ["automation", "workflow"] - assert result.masked_credentials == {} - assert result.is_team_authorization is True - assert result.allow_delete is True - assert result.plugin_id is None - assert result.plugin_unique_identifier is None - assert result.tools == [] diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index e9bcc89445..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,955 +0,0 @@ -""" -Unit tests for services.tools.workflow_tools_manage_service - -Covers WorkflowToolManageService: create, update, list, delete, get, list_single. -""" - -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service -from services.tools.workflow_tools_manage_service import WorkflowToolManageService - -# --------------------------------------------------------------------------- -# Shared helpers / fake infrastructure -# --------------------------------------------------------------------------- - - -class DummyWorkflow: - """Minimal in-memory Workflow substitute.""" - - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - """Chainable query object that always returns a fixed result.""" - - def __init__(self, result: object) -> None: - self._result = result - - def where(self, *args: object, **kwargs: object) -> "FakeQuery": - return self - - def first(self) -> object: - return self._result - - def delete(self) -> int: - return 1 - - -class DummySession: - """Minimal SQLAlchemy session substitute.""" - - def __init__(self) -> None: - self.added: list[WorkflowToolProvider] = [] - self.committed: bool = False - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: - return False - - def add(self, obj: WorkflowToolProvider) -> None: - self.added.append(obj) - - def begin(self) -> "DummySession": - return self - - def commit(self) -> None: - self.committed = True - - -def _build_parameters() -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM), - ] - - -def _build_fake_db( - *, - existing_tool: WorkflowToolProvider | None = None, - app: object | None = None, - tool_by_id: WorkflowToolProvider | None = None, -) -> tuple[MagicMock, DummySession]: - """ - Build a fake db object plus a DummySession for Session context-manager. - - query(WorkflowToolProvider) returns existing_tool on first call, - then tool_by_id on subsequent calls (or None if not provided). - query(App) returns app. - """ - call_counts: dict[str, int] = {"wftp": 0} - - def query(model: type) -> FakeQuery: - if model is WorkflowToolProvider: - call_counts["wftp"] += 1 - if call_counts["wftp"] == 1: - return FakeQuery(existing_tool) - return FakeQuery(tool_by_id) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query, commit=MagicMock()) - dummy_session = DummySession() - return fake_db, dummy_session - - -# --------------------------------------------------------------------------- -# TestCreateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestCreateWorkflowTool: - """Tests for WorkflowToolManageService.create_workflow_tool.""" - - def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Human-input nodes must be rejected before any provider is created.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app)) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - # Act + Assert - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "🔧"}, - description="desc", - parameters=_build_parameters(), - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - - def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Existing provider with same name or app_id raises ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(existing)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-1", - name="dup", - label="Dup", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the referenced App does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(None) # App returns None - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="missing-app", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App has no attached Workflow.""" - # Arrange - app_no_workflow = SimpleNamespace(workflow=None) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app_no_workflow) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("bad config")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="bad config"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider is added to session and success dict is returned.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0") - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🔧"} - - # Act - result = WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created: WorkflowToolProvider = dummy_session.added[0] - assert created.name == "tool_name" - assert created.label == "Tool" - assert created.icon == json.dumps(icon) - assert created.version == "2.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager when provided.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - mock_to_ctrl = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl - ) - - # Act - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["tag1", "tag2"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestUpdateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestUpdateWorkflowTool: - """Tests for WorkflowToolManageService.update_workflow_tool.""" - - def _make_provider(self) -> WorkflowToolProvider: - p = MagicMock(spec=WorkflowToolProvider) - p.app_id = "app-id" - p.tenant_id = "tenant-id" - return p - - def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None: - """If another tool with the given name already exists, raise ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - - def query(m: type) -> FakeQuery: - return FakeQuery(existing) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="dup", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the workflow tool to update does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - # 1st call: name uniqueness check → None (no duplicate) - # 2nd call: fetch tool by id → None (not found) - return FakeQuery(None) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="missing", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the tool's referenced App has been removed.""" - # Arrange - provider = self._make_provider() - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - # 1st: duplicate name check (None), 2nd: fetch provider - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(None) # App not found - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App exists but has no Workflow.""" - # Arrange - provider = self._make_provider() - app_no_wf = SimpleNamespace(workflow=None) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app_no_wf) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Exceptions from from_db are re-raised as ValueError.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("from_db error")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="from_db error"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider fields are updated and session committed.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0") - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - mock_commit = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=mock_commit), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🛠"} - - # Act - result = WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="new_name", - label="New Label", - icon=icon, - description="new desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - assert provider.name == "new_name" - assert provider.label == "New Label" - assert provider.icon == json.dumps(icon) - assert provider.version == "3.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager during update.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock() - ) - - # Act - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["a"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestListTenantWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListTenantWorkflowTools: - """Tests for WorkflowToolManageService.list_tenant_workflow_tools.""" - - def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """An empty database yields an empty result list.""" - # Arrange - fake_scalars = MagicMock() - fake_scalars.all.return_value = [] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert result == [] - - def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Providers that fail to load are logged and skipped.""" - # Arrange - good_provider = MagicMock(spec=WorkflowToolProvider) - good_provider.id = "good-id" - good_provider.app_id = "app-good" - bad_provider = MagicMock(spec=WorkflowToolProvider) - bad_provider.id = "bad-id" - bad_provider.app_id = "app-bad" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [good_provider, bad_provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - good_ctrl = MagicMock() - good_ctrl.provider_id = "good-id" - - def to_controller(provider: WorkflowToolProvider) -> MagicMock: - if provider is bad_provider: - raise RuntimeError("broken provider") - return good_ctrl - - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller - ) - mock_get_labels = MagicMock(return_value={}) - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels) - mock_to_user = MagicMock() - mock_to_user.return_value.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - mock_get_tools = MagicMock(return_value=[MagicMock()]) - good_ctrl.get_tools = mock_get_tools - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - only good provider contributed - assert len(result) == 1 - - def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None: - """All successfully loaded providers appear in the result.""" - # Arrange - provider = MagicMock(spec=WorkflowToolProvider) - provider.id = "p-1" - provider.app_id = "app-1" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - ctrl = MagicMock() - ctrl.provider_id = "p-1" - ctrl.get_tools.return_value = [MagicMock()] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []}) - ) - user_provider = MagicMock() - user_provider.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_user_provider", - MagicMock(return_value=user_provider), - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert len(result) == 1 - assert result[0] is user_provider - - -# --------------------------------------------------------------------------- -# TestDeleteWorkflowTool -# --------------------------------------------------------------------------- - - -class TestDeleteWorkflowTool: - """Tests for WorkflowToolManageService.delete_workflow_tool.""" - - def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """delete_workflow_tool queries, deletes, commits, and returns success.""" - # Arrange - mock_query = MagicMock() - mock_query.where.return_value.delete.return_value = 1 - mock_commit = MagicMock() - fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - # Act - result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1") - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestGetWorkflowToolByToolId / ByAppId -# --------------------------------------------------------------------------- - - -class TestGetWorkflowToolByToolIdAndAppId: - """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id.""" - - def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by tool id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing") - - def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by app id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app") - - -# --------------------------------------------------------------------------- -# TestGetWorkflowTool (private _get_workflow_tool) -# --------------------------------------------------------------------------- - - -class TestGetWorkflowTool: - """Tests for the internal _get_workflow_tool helper.""" - - def test_should_raise_when_db_tool_none(self) -> None: - """_get_workflow_tool raises ValueError when db_tool is None.""" - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService._get_workflow_tool("t", None) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the corresponding App row is missing.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when App has no attached Workflow.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - app = SimpleNamespace(workflow=None) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller returns no WorkflowTool instances.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - db_tool.id = "tool-1" - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns a dict with name, label, icon, synced, etc.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - db_tool.id = "tool-1" - db_tool.name = "my_tool" - db_tool.label = "My Tool" - db_tool.icon = json.dumps({"emoji": "🔧"}) - db_tool.description = "some desc" - db_tool.privacy_policy = "" - db_tool.version = "1.0" - db_tool.parameter_configurations = [] - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0") - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - workflow_tool = MagicMock() - workflow_tool.entity.output_schema = {"type": "object"} - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - mock_convert = MagicMock(return_value={"tool": "api_entity"}) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService._get_workflow_tool("t", db_tool) - - # Assert - assert result["name"] == "my_tool" - assert result["label"] == "My Tool" - assert result["synced"] is True - assert "icon" in result - assert "output_schema" in result - - -# --------------------------------------------------------------------------- -# TestListSingleWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListSingleWorkflowTools: - """Tests for WorkflowToolManageService.list_single_workflow_tools.""" - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the specified tool does not exist in DB.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller yields no tools for the provider.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns list with one ToolApiEntity.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - workflow_tool = MagicMock() - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - api_entity = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "convert_tool_entity_to_api_entity", - MagicMock(return_value=api_entity), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - # Assert - assert result == [api_entity] diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b2..ee9ba1c6d6 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,6 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -151,8 +152,8 @@ class VectorServiceTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", - indexing_technique: str = "high_quality", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", index_struct_dict: dict | None = None, @@ -493,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -505,7 +506,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model") + VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_called_once() @@ -521,7 +522,7 @@ class TestVectorService: assert call_args[1]["keywords_list"] == keywords_list @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager") + @patch("services.vector_service.ModelManager.for_tenant") @patch("services.vector_service.db") def test_create_segments_vector_parent_child_indexing( self, mock_db, mock_model_manager, mock_generate_child_chunks @@ -534,7 +535,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -567,7 +568,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -590,7 +591,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -615,7 +616,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="economy" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -649,7 +650,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_not_called() @@ -668,7 +669,7 @@ class TestVectorService: store when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -694,7 +695,7 @@ class TestVectorService: index when using economy indexing with keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -730,7 +731,7 @@ class TestVectorService: index when using economy indexing without keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -894,7 +895,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -922,7 +923,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -950,7 +951,7 @@ class TestVectorService: when there are new chunks, updated chunks, and deleted chunks. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") @@ -992,7 +993,7 @@ class TestVectorService: add_texts is called, not delete_by_ids. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1018,7 +1019,7 @@ class TestVectorService: delete_by_ids is called, not add_texts. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1044,7 +1045,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1074,7 +1075,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1098,7 +1099,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1754,7 +1755,7 @@ class TestVector: # ======================================================================== @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager") + @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): """ diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index f3391d6380..2db83576b0 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -6,8 +6,11 @@ from unittest.mock import Mock, patch import pytest from sqlalchemy import Engine -from dify_graph.variables.segments import ObjectSegment, StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.file_reference import build_file_reference +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -54,25 +57,18 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_content.encode() - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_variable" - mock_variable.value = StringSegment(value=test_content) - mock_segment_to_variable.return_value = mock_variable + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify results + assert selector_tuple == ("test-node-id", "test_variable") + assert variable.id == "draft-var-id" + assert variable.name == "test_variable" + assert variable.description == "test description" + assert variable.value == test_content - # Verify results - assert selector_tuple == ("test-node-id", "test_variable") - assert variable.id == "draft-var-id" - assert variable.name == "test_variable" - assert variable.description == "test description" - assert variable.value == test_content - - # Verify storage was called correctly - mock_storage.load.assert_called_once_with("storage/key/test.txt") + # Verify storage was called correctly + mock_storage.load.assert_called_once_with("storage/key/test.txt") def test_load_offloaded_variable_object_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with object type - isolated unit test.""" @@ -97,31 +93,22 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + mock_segment = ObjectSegment(value=test_object) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - mock_segment = ObjectSegment(value=test_object) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_object" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_object") + assert variable.id == "draft-var-id" + assert variable.name == "test_object" + assert variable.description == "test description" + assert variable.value == test_object - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_object") - assert variable.id == "draft-var-id" - assert variable.name == "test_object" - assert variable.description == "test description" - assert variable.value == test_object - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test.json") - mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.OBJECT, test_object) def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader): """Test that assertion error is raised when variable_file is None.""" @@ -176,32 +163,23 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import FloatSegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import FloatSegment + mock_segment = FloatSegment(value=test_number) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = FloatSegment(value=test_number) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_number" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_number") + assert variable.id == "draft-var-id" + assert variable.name == "test_number" + assert variable.description == "test number description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_number") - assert variable.id == "draft-var-id" - assert variable.name == "test_number" - assert variable.description == "test number description" - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_number.json") - mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_number.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.NUMBER, test_number) def test_load_offloaded_variable_array_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with array type - isolated unit test.""" @@ -226,32 +204,83 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import ArrayAnySegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import ArrayAnySegment + mock_segment = ArrayAnySegment(value=test_array) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = ArrayAnySegment(value=test_array) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_array" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_array") + assert variable.id == "draft-var-id" + assert variable.name == "test_array" + assert variable.description == "test array description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_array.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) - # Verify results - assert selector_tuple == ("test-node-id", "test_array") - assert variable.id == "draft-var-id" - assert variable.name == "test_array" - assert variable.description == "test array description" + def test_load_offloaded_variable_file_type_rebuilds_storage_backed_payload(self, draft_var_loader): + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test_file.json" - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_array.json") - mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.FILE + variable_file.upload_file = upload_file + + draft_var = WorkflowDraftVariable() + draft_var.id = "draft-var-id" + draft_var.app_id = "app-1" + draft_var.node_id = "test-node-id" + draft_var.name = "test_file" + draft_var.description = "test file description" + draft_var._set_selector(["test-node-id", "test_file"]) + draft_var.variable_file = variable_file + + persisted_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1", storage_key="legacy-storage-key"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + + raw_file = { + **persisted_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch("services.workflow_draft_variable_service.storage") as mock_storage, + patch("models.workflow._resolve_workflow_app_tenant_id", return_value="tenant-1"), + patch("models.workflow.build_file_from_stored_mapping", return_value=rebuilt_file) as rebuild_file, + ): + mock_storage.load.return_value = json.dumps(raw_file).encode() + + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + assert selector_tuple == ("test-node-id", "test_file") + assert variable.id == "draft-var-id" + assert variable.name == "test_file" + assert variable.description == "test file description" + assert variable.value == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader): """Test load_variables method with mix of regular and offloaded variables.""" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py deleted file mode 100644 index a847c2b4d1..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ /dev/null @@ -1,431 +0,0 @@ -# test for api/services/workflow/workflow_converter.py -import json -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import ( - AdvancedChatMessageEntity, - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.helper import encrypter -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import AppMode -from services.workflow.workflow_converter import WorkflowConverter - - -@pytest.fixture -def default_variables(): - value = [ - VariableEntity( - variable="text_input", - label="text-input", - type=VariableEntityType.TEXT_INPUT, - ), - VariableEntity( - variable="paragraph", - label="paragraph", - type=VariableEntityType.PARAGRAPH, - ), - VariableEntity( - variable="select", - label="select", - type=VariableEntityType.SELECT, - ), - ] - return value - - -def test__convert_to_start_node(default_variables): - # act - result = WorkflowConverter()._convert_to_start_node(default_variables) - - # assert - assert isinstance(result["data"]["variables"][0]["type"], str) - assert result["data"]["variables"][0]["type"] == "text-input" - assert result["data"]["variables"][0]["variable"] == "text_input" - assert result["data"]["variables"][1]["variable"] == "paragraph" - assert result["data"]["variables"][2]["variable"] == "select" - - -def test__convert_to_http_request_node_for_chatbot(default_variables): - """ - Test convert to http request nodes for chatbot - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - - mock_api_based_extension.id = api_based_extension_id - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "{{#sys.query#}}" # for chatbot - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_http_request_node_for_workflow_app(default_variables): - """ - Test convert to http request nodes for workflow app - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - mock_api_based_extension.id = api_based_extension_id - - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "" - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_knowledge_retrieval_node_for_chatbot(): - new_app_mode = AppMode.ADVANCED_CHAT - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["sys", "query"] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_knowledge_retrieval_node_for_workflow_app(): - new_app_mode = AppMode.WORKFLOW - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - query_variable="query", - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( - messages=[ - AdvancedChatMessageEntity( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM, - ), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ] - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], list) - assert prompt_template.advanced_chat_prompt_template is not None - assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) - template = prompt_template.advanced_chat_prompt_template.messages[0].text - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template - - -def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], dict) - assert prompt_template.advanced_completion_prompt_template is not None - template = prompt_template.advanced_completion_prompt_template.prompt - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py deleted file mode 100644 index dfe325648d..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ /dev/null @@ -1,127 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import Session - -from models.model import App -from models.workflow import Workflow -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService - - -@pytest.fixture -def workflow_setup(): - mock_session_maker = MagicMock() - workflow_service = WorkflowService(mock_session_maker) - session = MagicMock(spec=Session) - tenant_id = "test-tenant-id" - workflow_id = "test-workflow-id" - - # Mock workflow - workflow = MagicMock(spec=Workflow) - workflow.id = workflow_id - workflow.tenant_id = tenant_id - workflow.version = "1.0" # Not a draft - workflow.tool_published = False # Not published as a tool by default - - # Mock app - app = MagicMock(spec=App) - app.id = "test-app-id" - app.name = "Test App" - app.workflow_id = None # Not used by an app by default - - return { - "workflow_service": workflow_service, - "session": session, - "tenant_id": tenant_id, - "workflow_id": workflow_id, - "workflow": workflow, - "app": app, - } - - -def test_delete_workflow_success(workflow_setup): - # Setup mocks - - # Mock the tool provider query to return None (not published as a tool) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = None - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method - result = workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - assert result is True - workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"]) - - -def test_delete_workflow_draft_error(workflow_setup): - # Setup mocks - workflow_setup["workflow"].version = "draft" - workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"]) - - # Call the method and verify exception - with pytest.raises(DraftWorkflowDeletionError): - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_in_use_by_app_error(workflow_setup): - # Setup mocks - workflow_setup["app"].workflow_id = workflow_setup["workflow_id"] - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], workflow_setup["app"]] - ) # Return workflow first, then app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message contains app name - assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_published_as_tool_error(workflow_setup): - # Setup mocks - from models.tools import WorkflowToolProvider - - # Mock the tool provider query - mock_tool_provider = MagicMock(spec=WorkflowToolProvider) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message - assert "Cannot delete workflow that is published as a tool" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 0c2be9c79f..6200c9f859 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -7,10 +7,17 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from graphon.enums import BuiltinNodeTypes +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType @@ -86,6 +93,20 @@ class TestDraftVariableSaver: expected_node_id=_NODE_ID, expected_name="start_input", ), + TestCase( + name="name with `env.` prefix should return the environment node_id", + input_node_id=_NODE_ID, + input_name="env.API_KEY", + expected_node_id=ENVIRONMENT_VARIABLE_NODE_ID, + expected_name="API_KEY", + ), + TestCase( + name="name with `conversation.` prefix should return the conversation node_id", + input_node_id=_NODE_ID, + input_name="conversation.session_id", + expected_node_id=CONVERSATION_VARIABLE_NODE_ID, + expected_name="session_id", + ), TestCase( name="dummy_variable should return the original input node_id", input_node_id=_NODE_ID, @@ -112,6 +133,47 @@ class TestDraftVariableSaver: assert node_id == c.expected_node_id, fail_msg assert name == c.expected_name, fail_msg + def test_build_variables_from_start_mapping_rebuilds_system_files(self): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = str(uuid.uuid4()) + saver = DraftVariableSaver( + session=mock_session, + app_id=self._get_test_app_id(), + node_id="start", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-1", + user=mock_user, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference="upload-1", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + raw_file = { + **rebuilt_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch.object(saver, "_resolve_app_tenant_id", return_value="tenant-1"), + patch( + "services.workflow_draft_variable_service.build_file_from_stored_mapping", + return_value=rebuilt_file, + ) as rebuild_file, + ): + draft_vars = saver._build_variables_from_start_mapping({"sys.files": [raw_file]}) + + sys_var = draft_vars[0] + assert sys_var.get_value().value[0] == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") + @pytest.fixture def mock_session(self): """Mock SQLAlchemy session.""" @@ -218,6 +280,46 @@ class TestDraftVariableSaver: str(SystemVariableKey.WORKFLOW_EXECUTION_ID), } + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True) + def test_start_node_save_normalizes_reserved_prefix_outputs(self, mock_batch_upsert): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = "test-user-id" + mock_user.tenant_id = "test-tenant-id" + + saver = DraftVariableSaver( + session=mock_session, + app_id="test-app-id", + node_id="start-node-id", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-id", + user=mock_user, + ) + + saver.save( + outputs={ + "env.API_KEY": "secret", + "conversation.session_id": "conversation-1", + "sys.workflow_run_id": "run-id-123", + } + ) + + mock_batch_upsert.assert_called_once() + draft_vars = mock_batch_upsert.call_args[0][1] + + assert len(draft_vars) == 3 + + env_var = next(v for v in draft_vars if v.node_id == ENVIRONMENT_VARIABLE_NODE_ID) + assert env_var.name == "API_KEY" + assert env_var.editable is False + + conversation_var = next(v for v in draft_vars if v.node_id == CONVERSATION_VARIABLE_NODE_ID) + assert conversation_var.name == "session_id" + assert conversation_var.node_execution_id is None + + sys_var = next(v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.name == str(SystemVariableKey.WORKFLOW_EXECUTION_ID) + class TestWorkflowDraftVariableService: def _get_test_app_id(self): diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 6c1adba2b8..ce66b78b64 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -12,9 +12,9 @@ import pytest from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index c890ab6a65..d7192994b2 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,16 +5,16 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -23,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: +def _build_node_config(delivery_methods: list[EmailDeliveryMethod], *, legacy: bool = False) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,6 +31,14 @@ def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfi inputs=[], user_actions=[], ).model_dump(mode="json") + if legacy: + for delivery_method in node_data["delivery_methods"]: + recipients = delivery_method.get("config", {}).get("recipients", {}) + if "include_bound_group" in recipients: + recipients["whole_workspace"] = recipients.pop("include_bound_group") + for recipient in recipients.get("items", []): + if "reference_id" in recipient: + recipient["user_id"] = recipient.pop("reference_id") node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) @@ -41,7 +49,7 @@ def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailD enabled=enabled, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="tester@example.com")], ), subject="Test subject", @@ -69,7 +77,7 @@ def test_human_input_delivery_requires_draft_workflow(): def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -105,7 +113,7 @@ def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyP def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -144,7 +152,7 @@ def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.Mon def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -178,8 +186,8 @@ def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytes sent_method = test_service_instance.send_test.call_args.kwargs["method"] assert isinstance(sent_method, EmailDeliveryMethod) assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False + assert sent_method.config.recipients.include_bound_group is False assert len(sent_method.config.recipients.items) == 1 recipient = sent_method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id + assert recipient.reference_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py deleted file mode 100644 index 79bf5e94c2..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from repositories.sqlalchemy_api_workflow_node_execution_repository import ( - DifyAPISQLAlchemyWorkflowNodeExecutionRepository, -) - - -class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: - @pytest.fixture - def repository(self): - mock_session_maker = MagicMock() - return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) - - def test_repository_implements_protocol(self, repository): - """Test that the repository implements the required protocol methods.""" - # Verify all protocol methods are implemented - assert hasattr(repository, "get_node_last_execution") - assert hasattr(repository, "get_executions_by_workflow_run") - assert hasattr(repository, "get_execution_by_id") - - # Verify methods are callable - assert callable(repository.get_node_last_execution) - assert callable(repository.get_executions_by_workflow_run) - assert callable(repository.get_execution_by_id) - assert callable(repository.delete_expired_executions) - assert callable(repository.delete_executions_by_app) - assert callable(repository.get_expired_executions_batch) - assert callable(repository.delete_executions_by_ids) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 538c1b3595..6b04a1bc09 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock import pytest -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import FormInputType +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow from services import workflow_service as workflow_service_module diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 74ba7f9c34..936a10d6c5 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -183,10 +184,10 @@ class TestErrorHandling: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -228,10 +229,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=pipeline_id, ) @@ -264,10 +265,10 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=None, ) @@ -320,10 +321,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -365,10 +366,10 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert - storage delete was attempted @@ -407,10 +408,10 @@ class TestEdgeCases: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -444,7 +445,7 @@ class TestIndexProcessorParameters: - Dataset object with correct attributes is passed """ # Arrange - indexing_technique = "high_quality" + indexing_technique = IndexTechniqueType.HIGH_QUALITY index_struct = '{"type": "paragraph"}' # Act @@ -454,7 +455,7 @@ class TestIndexProcessorParameters: indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 8a721124d6..0b189ebae2 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -58,6 +59,11 @@ def mock_redis(): # Redis is already mocked globally in conftest.py # Reset it for each test redis_client.reset_mock() + redis_client.get.reset_mock() + redis_client.setex.reset_mock() + redis_client.delete.reset_mock() + redis_client.lpush.reset_mock() + redis_client.rpop.reset_mock() redis_client.get.return_value = None redis_client.setex.return_value = True redis_client.delete.return_value = True @@ -203,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id): dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -222,7 +228,7 @@ def mock_documents(document_ids, dataset_id): doc.stopped_at = None doc.processing_started_at = None # optional attribute used in some code paths - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX documents.append(doc) return documents diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 3668416e36..f49f4535af 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, document.tenant_id = str(uuid.uuid4()) document.data_source_type = "notion_import" document.indexing_status = "completed" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index bd0182a402..591da56f49 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index a223f0119e..8cac696d98 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -11,11 +11,11 @@ # import pytest -# from dify_graph.entities.workflow_node_execution import ( +# from graphon.entities.workflow_node_execution import ( # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from dify_graph.enums import BuiltinNodeTypes +# from graphon.enums import BuiltinNodeTypes # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index fa9c6af287..f31bf80046 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -17,7 +17,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 7ec1343f98..c166a946d9 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -5,7 +5,7 @@ import pytest from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, @@ -13,13 +13,13 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMResultWithStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: @@ -336,7 +336,6 @@ def test_structured_output_parser(): json_schema=case["json_schema"], stream=case["stream"], model_parameters={"temperature": 0.7, "max_tokens": 100}, - user="test_user", ) if case["expected_result_type"] == "generator": @@ -367,7 +366,7 @@ def test_structured_output_parser(): call_args = model_instance.invoke_llm.call_args assert call_args.kwargs["stream"] == case["stream"] - assert call_args.kwargs["user"] == "test_user" + assert "user" not in call_args.kwargs assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"] diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index 1f0bf8ef37..a29df0bb6b 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -2,7 +2,10 @@ from collections.abc import Mapping from typing import Any from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from dify_graph.entities.graph_init_params import GraphInitParams +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from graphon.entities.graph_init_params import GraphInitParams +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable def build_test_run_context( @@ -51,3 +54,16 @@ def build_test_graph_init_params( ), call_depth=call_depth, ) + + +def build_test_variable_pool( + *, + variables: list[Variable] | tuple[Variable, ...] = (), + node_id: str | None = None, + inputs: Mapping[str, Any] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + if node_id is not None and inputs is not None: + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=inputs) + return variable_pool diff --git a/api/uv.lock b/api/uv.lock index ebfc6678fe..ed2b76ac3c 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -169,12 +169,6 @@ version = "1.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a0/87/1d7019d23891897cb076b2f7e3c81ab3c2ba91de3bb067196f675d60d34c/alibabacloud-credentials-api-1.0.0.tar.gz", hash = "sha256:8c340038d904f0218d7214a8f4088c31912bfcf279af2cbc7d9be4897a97dd2f", size = 2330, upload-time = "2025-01-13T05:53:04.931Z" } -[[package]] -name = "alibabacloud-endpoint-util" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/92/7d/8cc92a95c920e344835b005af6ea45a0db98763ad6ad19299d26892e6c8d/alibabacloud_endpoint_util-0.0.4.tar.gz", hash = "sha256:a593eb8ddd8168d5dc2216cd33111b144f9189fcd6e9ca20e48f358a739bbf90", size = 2813, upload-time = "2025-06-12T07:20:52.572Z" } - [[package]] name = "alibabacloud-gateway-spi" version = "0.0.3" @@ -186,69 +180,17 @@ sdist = { url = "https://files.pythonhosted.org/packages/ab/98/d7111245f17935bf7 [[package]] name = "alibabacloud-gpdb20160503" -version = "3.8.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-openplatform20191219" }, - { name = "alibabacloud-oss-sdk" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/15/6a/cc72e744e95c8f37fa6a84e66ae0b9b57a13ee97a0ef03d94c7127c31d75/alibabacloud_gpdb20160503-3.8.3.tar.gz", hash = "sha256:4dfcc0d9cff5a921d529d76f4bf97e2ceb9dc2fa53f00ab055f08509423d8e30", size = 155092, upload-time = "2024-07-18T17:09:42.438Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/36/bce41704b3bf59d607590ec73a42a254c5dea27c0f707aee11d20512a200/alibabacloud_gpdb20160503-3.8.3-py3-none-any.whl", hash = "sha256:06e1c46ce5e4e9d1bcae76e76e51034196c625799d06b2efec8d46a7df323fe8", size = 156097, upload-time = "2024-07-18T17:09:40.414Z" }, -] - -[[package]] -name = "alibabacloud-openapi-util" -version = "0.2.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea-util" }, - { name = "cryptography" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f6/50/5f41ab550d7874c623f6e992758429802c4b52a6804db437017e5387de33/alibabacloud_openapi_util-0.2.2.tar.gz", hash = "sha256:ebbc3906f554cb4bf8f513e43e8a33e8b6a3d4a0ef13617a0e14c3dda8ef52a8", size = 7201, upload-time = "2023-10-23T07:44:18.523Z" } - -[[package]] -name = "alibabacloud-openplatform20191219" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-endpoint-util" }, - { name = "alibabacloud-openapi-util" }, - { name = "alibabacloud-tea-openapi" }, - { name = "alibabacloud-tea-util" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4f/bf/f7fa2f3657ed352870f442434cb2f27b7f70dcd52a544a1f3998eeaf6d71/alibabacloud_openplatform20191219-2.0.0.tar.gz", hash = "sha256:e67f4c337b7542538746592c6a474bd4ae3a9edccdf62e11a32ca61fad3c9020", size = 5038, upload-time = "2022-09-21T06:16:10.683Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/e5/18c75213551eeca9db1f6b41ddcc0bd87b5b6508c75a67f05cd8671847b4/alibabacloud_openplatform20191219-2.0.0-py3-none-any.whl", hash = "sha256:873821c45bca72a6c6ec7a906c9cb21554c122e88893bbac3986934dab30dd36", size = 5204, upload-time = "2022-09-21T06:16:07.844Z" }, -] - -[[package]] -name = "alibabacloud-oss-sdk" -version = "0.1.1" +version = "5.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, - { name = "alibabacloud-oss-util" }, - { name = "alibabacloud-tea-fileform" }, - { name = "alibabacloud-tea-util" }, - { name = "alibabacloud-tea-xml" }, + { name = "alibabacloud-tea-openapi" }, + { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/d1/f442dd026908fcf55340ca694bb1d027aa91e119e76ae2fbea62f2bde4f4/alibabacloud_oss_sdk-0.1.1.tar.gz", hash = "sha256:f51a368020d0964fcc0978f96736006f49f5ab6a4a4bf4f0b8549e2c659e7358", size = 46434, upload-time = "2025-04-22T12:40:41.717Z" } - -[[package]] -name = "alibabacloud-oss-util" -version = "0.0.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, +sdist = { url = "https://files.pythonhosted.org/packages/b3/36/69333c7fb7fb5267f338371b14fdd8dbdd503717c97bbc7a6419d155ab4c/alibabacloud_gpdb20160503-5.1.0.tar.gz", hash = "sha256:086ec6d5e39b64f54d0e44bb3fd4fde1a4822a53eb9f6ff7464dff7d19b07b63", size = 295641, upload-time = "2026-03-19T10:09:02.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/7f/a91a2f9ad97c92fa9a6981587ea0ff789240cea05b17b17b7c244e5bac64/alibabacloud_gpdb20160503-5.1.0-py3-none-any.whl", hash = "sha256:580e4579285a54c7f04570782e0f60423a1997568684187fe88e4110acfb640e", size = 848784, upload-time = "2026-03-19T10:09:00.72Z" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/7c/d7e812b9968247a302573daebcfef95d0f9a718f7b4bfcca8d3d83e266be/alibabacloud_oss_util-0.0.6.tar.gz", hash = "sha256:d3ecec36632434bd509a113e8cf327dc23e830ac8d9dd6949926f4e334c8b5d6", size = 10008, upload-time = "2021-04-28T09:25:04.056Z" } [[package]] name = "alibabacloud-tea" @@ -260,15 +202,6 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0f3f9d7f26b76b9ed93d4101add7867a2c87ed2534f5/alibabacloud-tea-0.4.3.tar.gz", hash = "sha256:ec8053d0aa8d43ebe1deb632d5c5404339b39ec9a18a0707d57765838418504a", size = 8785, upload-time = "2025-03-24T07:34:42.958Z" } -[[package]] -name = "alibabacloud-tea-fileform" -version = "0.0.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/22/8a/ef8ddf5ee0350984cad2749414b420369fe943e15e6d96b79be45367630e/alibabacloud_tea_fileform-0.0.5.tar.gz", hash = "sha256:fd00a8c9d85e785a7655059e9651f9e91784678881831f60589172387b968ee8", size = 3961, upload-time = "2021-04-28T09:22:54.56Z" } - [[package]] name = "alibabacloud-tea-openapi" version = "0.4.3" @@ -297,15 +230,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/9e/c394b4e2104766fb28a1e44e3ed36e4c7773b4d05c868e482be99d5635c9/alibabacloud_tea_util-0.3.14-py3-none-any.whl", hash = "sha256:10d3e5c340d8f7ec69dd27345eb2fc5a1dab07875742525edf07bbe86db93bfe", size = 6697, upload-time = "2025-11-19T06:01:07.355Z" }, ] -[[package]] -name = "alibabacloud-tea-xml" -version = "0.0.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "alibabacloud-tea" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f3feae9b5681588762929dc4da0176667297c2784c1a/alibabacloud_tea_xml-0.0.3.tar.gz", hash = "sha256:979cb51fadf43de77f41c69fc69c12529728919f849723eb0cd24eb7b048a90c", size = 3466, upload-time = "2025-07-01T08:04:55.144Z" } - [[package]] name = "aliyun-log-python-sdk" version = "0.9.37" @@ -570,28 +494,28 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.38.2" +version = "1.38.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/a3/20aa7c4e83f2f614e0036300f3c352775dede0655c66814da16c37b661a9/basedpyright-1.38.2.tar.gz", hash = "sha256:b433b2b8ba745ed7520cdc79a29a03682f3fb00346d272ece5944e9e5e5daa92", size = 25277019, upload-time = "2026-02-26T11:18:43.594Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/58/7abba2c743571a42b2548f07aee556ebc1e4d0bc2b277aeba1ee6c83b0af/basedpyright-1.38.3.tar.gz", hash = "sha256:9725419786afbfad8a9539527f162da02d462afad440b0412fdb3f3cdf179b90", size = 25277430, upload-time = "2026-03-17T13:10:41.526Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/12/736cab83626fea3fe65cdafb3ef3d2ee9480c56723f2fd33921537289a5e/basedpyright-1.38.2-py3-none-any.whl", hash = "sha256:153481d37fd19f9e3adedc8629d1d071b10c5f5e49321fb026b74444b7c70e24", size = 12312475, upload-time = "2026-02-26T11:18:40.373Z" }, + { url = "https://files.pythonhosted.org/packages/2c/e3/3ebb5c23bd3abb5fc2053b8a06a889aa5c1cf8cff738c78cb6c1957e90cd/basedpyright-1.38.3-py3-none-any.whl", hash = "sha256:1f15c2e489c67d6c5e896c24b6a63251195c04223a55e4568b8f8e8ed49ca830", size = 12313363, upload-time = "2026-03-17T13:10:47.344Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.63" +version = "0.9.64" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/ab/4c2927b01a97562af6a296b722eee79658335795f341a395a12742d5e1a3/bce_python_sdk-0.9.63.tar.gz", hash = "sha256:0c80bc3ac128a0a144bae3b8dff1f397f42c30b36f7677e3a39d8df8e77b1088", size = 284419, upload-time = "2026-03-06T14:54:06.592Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/047e9c1a6c97e0cd4d93a6490abd8fbc2ccd13569462fc0228699edc08bc/bce_python_sdk-0.9.64.tar.gz", hash = "sha256:901bf787c26ad35855a80d65e58d7584c8541f7f0f2af20847830e572e5b622e", size = 287125, upload-time = "2026-03-17T11:24:29.345Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/a4/501e978776c7060aa8ba77e68536597e754d938bcdbe1826618acebfbddf/bce_python_sdk-0.9.63-py3-none-any.whl", hash = "sha256:ec66eee8807c6aa4036412592da7e8c9e2cd7fdec494190986288ac2195d8276", size = 400305, upload-time = "2026-03-06T14:53:52.887Z" }, + { url = "https://files.pythonhosted.org/packages/48/7f/dd289582f37ab4effea47b2a8503880db4781ca0fc8e0a8ed5ff493359e5/bce_python_sdk-0.9.64-py3-none-any.whl", hash = "sha256:eaad97e4f0e7d613ae978da3cdc5294e9f724ffca2735f79820037fa1317cd6d", size = 402233, upload-time = "2026-03-17T11:24:24.673Z" }, ] [[package]] @@ -660,14 +584,14 @@ wheels = [ [[package]] name = "bleach" -version = "6.2.0" +version = "6.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/9a/0e33f5054c54d349ea62c277191c020c2d6ef1d65ab2cb1993f91ec846d1/bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f", size = 203083, upload-time = "2024-10-29T18:30:40.477Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e", size = 163406, upload-time = "2024-10-29T18:30:38.186Z" }, + { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, ] [[package]] @@ -706,30 +630,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/ae/60c642aa5413e560b671da825329f510b29a77274ed0f580bde77562294d/boto3-1.42.68.tar.gz", hash = "sha256:3f349f967ab38c23425626d130962bcb363e75f042734fe856ea8c5a00eef03c", size = 112761, upload-time = "2026-03-13T19:32:17.137Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/8b/d00575be514744ca4839e7d85bf4a8a3c7b6b4574433291e58d14c68ae09/boto3-1.42.73.tar.gz", hash = "sha256:d37b58d6cd452ca808dd6823ae19ca65b6244096c5125ef9052988b337298bae", size = 112775, upload-time = "2026-03-20T19:39:52.814Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/f6/dc6e993479dbb597d68223fbf61cb026511737696b15bd7d2a33e9b2c24f/boto3-1.42.68-py3-none-any.whl", hash = "sha256:dbff353eb7dc93cbddd7926ed24793e0174c04adbe88860dfa639568442e4962", size = 140556, upload-time = "2026-03-13T19:32:14.951Z" }, + { url = "https://files.pythonhosted.org/packages/aa/05/1fcf03d90abaa3d0b42a6bfd10231dd709493ecbacf794aa2eea5eae6841/boto3-1.42.73-py3-none-any.whl", hash = "sha256:1f81b79b873f130eeab14bb556417a7c66d38f3396b7f2fe3b958b3f9094f455", size = 140556, upload-time = "2026-03-20T19:39:50.298Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/8c/dd4b0c95ff008bed5a35ab411452ece121b355539d2a0b6dcd62a0c47be5/boto3_stubs-1.42.68.tar.gz", hash = "sha256:96ad1020735619483fb9b4da7a5e694b460bf2e18f84a34d5d175d0ffe8c4653", size = 101372, upload-time = "2026-03-13T19:49:54.867Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/c3/fcc47102c63278af25ad57c93d97dc393f4dbc54c0117a29c78f2b96ec1e/boto3_stubs-1.42.73.tar.gz", hash = "sha256:36f625769b5505c4bc627f16244b98de9e10dae3ac36f1aa0f0ebe2f201dc138", size = 101373, upload-time = "2026-03-20T19:59:51.463Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/15/3ca5848917214a168134512a5b45f856a56e913659888947a052e02031b5/boto3_stubs-1.42.68-py3-none-any.whl", hash = "sha256:ed7f98334ef7b2377fa8532190e63dc2c6d1dc895e3d7cb3d6d1c83771b81bf6", size = 70011, upload-time = "2026-03-13T19:49:42.801Z" }, + { url = "https://files.pythonhosted.org/packages/4b/57/d570ba61a2a0c7fe0c8667e41269a0480293cb53e1786d6661a2bd827fc5/boto3_stubs-1.42.73-py3-none-any.whl", hash = "sha256:bd658429069d8215247fc3abc003220cd875c24ab6eda7b3405090408afaacdf", size = 70009, upload-time = "2026-03-20T19:59:43.786Z" }, ] [package.optional-dependencies] @@ -739,16 +663,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.42.68" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3f/22/87502d5fbbfa8189406a617b30b1e2a3dc0ab2669f7268e91b385c1c1c7a/botocore-1.42.68.tar.gz", hash = "sha256:3951c69e12ac871dda245f48dac5c7dd88ea1bfdd74a8879ec356cf2874b806a", size = 14994514, upload-time = "2026-03-13T19:32:03.577Z" } +sdist = { url = "https://files.pythonhosted.org/packages/28/23/0c88ca116ef63b1ae77c901cd5d2095d22a8dbde9e80df74545db4a061b4/botocore-1.42.73.tar.gz", hash = "sha256:575858641e4949aaf2af1ced145b8524529edf006d075877af6b82ff96ad854c", size = 15008008, upload-time = "2026-03-20T19:39:40.082Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/2a/1428f6594799780fe6ee845d8e6aeffafe026cd16a70c878684e2dcbbfc8/botocore-1.42.68-py3-none-any.whl", hash = "sha256:9df7da26374601f890e2f115bfa573d65bf15b25fe136bb3aac809f6145f52ab", size = 14668816, upload-time = "2026-03-13T19:31:58.572Z" }, + { url = "https://files.pythonhosted.org/packages/8e/65/971f3d55015f4d133a6ff3ad74cd39f4b8dd8f53f7775a3c2ad378ea5145/botocore-1.42.73-py3-none-any.whl", hash = "sha256:7b62e2a12f7a1b08eb7360eecd23bb16fe3b7ab7f5617cf91b25476c6f86a0fe", size = 14681861, upload-time = "2026-03-20T19:39:35.341Z" }, ] [[package]] @@ -1290,41 +1214,41 @@ wheels = [ [[package]] name = "coverage" -version = "7.13.4" +version = "7.13.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/56/95b7e30fa389756cb56630faa728da46a27b8c6eb46f9d557c68fff12b65/coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91", size = 827239, upload-time = "2026-02-09T12:59:03.86Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/ad/b59e5b451cf7172b8d1043dc0fa718f23aab379bc1521ee13d4bd9bfa960/coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053", size = 219278, upload-time = "2026-02-09T12:56:31.673Z" }, - { url = "https://files.pythonhosted.org/packages/f1/17/0cb7ca3de72e5f4ef2ec2fa0089beafbcaaaead1844e8b8a63d35173d77d/coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11", size = 219783, upload-time = "2026-02-09T12:56:33.104Z" }, - { url = "https://files.pythonhosted.org/packages/ab/63/325d8e5b11e0eaf6d0f6a44fad444ae58820929a9b0de943fa377fe73e85/coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa", size = 250200, upload-time = "2026-02-09T12:56:34.474Z" }, - { url = "https://files.pythonhosted.org/packages/76/53/c16972708cbb79f2942922571a687c52bd109a7bd51175aeb7558dff2236/coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7", size = 252114, upload-time = "2026-02-09T12:56:35.749Z" }, - { url = "https://files.pythonhosted.org/packages/eb/c2/7ab36d8b8cc412bec9ea2d07c83c48930eb4ba649634ba00cb7e4e0f9017/coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00", size = 254220, upload-time = "2026-02-09T12:56:37.796Z" }, - { url = "https://files.pythonhosted.org/packages/d6/4d/cf52c9a3322c89a0e6febdfbc83bb45c0ed3c64ad14081b9503adee702e7/coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef", size = 256164, upload-time = "2026-02-09T12:56:39.016Z" }, - { url = "https://files.pythonhosted.org/packages/78/e9/eb1dd17bd6de8289df3580e967e78294f352a5df8a57ff4671ee5fc3dcd0/coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903", size = 250325, upload-time = "2026-02-09T12:56:40.668Z" }, - { url = "https://files.pythonhosted.org/packages/71/07/8c1542aa873728f72267c07278c5cc0ec91356daf974df21335ccdb46368/coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f", size = 251913, upload-time = "2026-02-09T12:56:41.97Z" }, - { url = "https://files.pythonhosted.org/packages/74/d7/c62e2c5e4483a748e27868e4c32ad3daa9bdddbba58e1bc7a15e252baa74/coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299", size = 249974, upload-time = "2026-02-09T12:56:43.323Z" }, - { url = "https://files.pythonhosted.org/packages/98/9f/4c5c015a6e98ced54efd0f5cf8d31b88e5504ecb6857585fc0161bb1e600/coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505", size = 253741, upload-time = "2026-02-09T12:56:45.155Z" }, - { url = "https://files.pythonhosted.org/packages/bd/59/0f4eef89b9f0fcd9633b5d350016f54126ab49426a70ff4c4e87446cabdc/coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6", size = 249695, upload-time = "2026-02-09T12:56:46.636Z" }, - { url = "https://files.pythonhosted.org/packages/b5/2c/b7476f938deb07166f3eb281a385c262675d688ff4659ad56c6c6b8e2e70/coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9", size = 250599, upload-time = "2026-02-09T12:56:48.13Z" }, - { url = "https://files.pythonhosted.org/packages/b8/34/c3420709d9846ee3785b9f2831b4d94f276f38884032dca1457fa83f7476/coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9", size = 221780, upload-time = "2026-02-09T12:56:50.479Z" }, - { url = "https://files.pythonhosted.org/packages/61/08/3d9c8613079d2b11c185b865de9a4c1a68850cfda2b357fae365cf609f29/coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f", size = 222715, upload-time = "2026-02-09T12:56:51.815Z" }, - { url = "https://files.pythonhosted.org/packages/18/1a/54c3c80b2f056164cc0a6cdcb040733760c7c4be9d780fe655f356f433e4/coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f", size = 221385, upload-time = "2026-02-09T12:56:53.194Z" }, - { url = "https://files.pythonhosted.org/packages/d1/81/4ce2fdd909c5a0ed1f6dedb88aa57ab79b6d1fbd9b588c1ac7ef45659566/coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459", size = 219449, upload-time = "2026-02-09T12:56:54.889Z" }, - { url = "https://files.pythonhosted.org/packages/5d/96/5238b1efc5922ddbdc9b0db9243152c09777804fb7c02ad1741eb18a11c0/coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3", size = 219810, upload-time = "2026-02-09T12:56:56.33Z" }, - { url = "https://files.pythonhosted.org/packages/78/72/2f372b726d433c9c35e56377cf1d513b4c16fe51841060d826b95caacec1/coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634", size = 251308, upload-time = "2026-02-09T12:56:57.858Z" }, - { url = "https://files.pythonhosted.org/packages/5d/a0/2ea570925524ef4e00bb6c82649f5682a77fac5ab910a65c9284de422600/coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3", size = 254052, upload-time = "2026-02-09T12:56:59.754Z" }, - { url = "https://files.pythonhosted.org/packages/e8/ac/45dc2e19a1939098d783c846e130b8f862fbb50d09e0af663988f2f21973/coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa", size = 255165, upload-time = "2026-02-09T12:57:01.287Z" }, - { url = "https://files.pythonhosted.org/packages/2d/4d/26d236ff35abc3b5e63540d3386e4c3b192168c1d96da5cb2f43c640970f/coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3", size = 257432, upload-time = "2026-02-09T12:57:02.637Z" }, - { url = "https://files.pythonhosted.org/packages/ec/55/14a966c757d1348b2e19caf699415a2a4c4f7feaa4bbc6326a51f5c7dd1b/coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a", size = 251716, upload-time = "2026-02-09T12:57:04.056Z" }, - { url = "https://files.pythonhosted.org/packages/77/33/50116647905837c66d28b2af1321b845d5f5d19be9655cb84d4a0ea806b4/coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7", size = 253089, upload-time = "2026-02-09T12:57:05.503Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b4/8efb11a46e3665d92635a56e4f2d4529de6d33f2cb38afd47d779d15fc99/coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc", size = 251232, upload-time = "2026-02-09T12:57:06.879Z" }, - { url = "https://files.pythonhosted.org/packages/51/24/8cd73dd399b812cc76bb0ac260e671c4163093441847ffe058ac9fda1e32/coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47", size = 255299, upload-time = "2026-02-09T12:57:08.245Z" }, - { url = "https://files.pythonhosted.org/packages/03/94/0a4b12f1d0e029ce1ccc1c800944a9984cbe7d678e470bb6d3c6bc38a0da/coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985", size = 250796, upload-time = "2026-02-09T12:57:10.142Z" }, - { url = "https://files.pythonhosted.org/packages/73/44/6002fbf88f6698ca034360ce474c406be6d5a985b3fdb3401128031eef6b/coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0", size = 252673, upload-time = "2026-02-09T12:57:12.197Z" }, - { url = "https://files.pythonhosted.org/packages/de/c6/a0279f7c00e786be75a749a5674e6fa267bcbd8209cd10c9a450c655dfa7/coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246", size = 221990, upload-time = "2026-02-09T12:57:14.085Z" }, - { url = "https://files.pythonhosted.org/packages/77/4e/c0a25a425fcf5557d9abd18419c95b63922e897bc86c1f327f155ef234a9/coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126", size = 222800, upload-time = "2026-02-09T12:57:15.944Z" }, - { url = "https://files.pythonhosted.org/packages/47/ac/92da44ad9a6f4e3a7debd178949d6f3769bedca33830ce9b1dcdab589a37/coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d", size = 221415, upload-time = "2026-02-09T12:57:17.497Z" }, - { url = "https://files.pythonhosted.org/packages/0d/4a/331fe2caf6799d591109bb9c08083080f6de90a823695d412a935622abb2/coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0", size = 211242, upload-time = "2026-02-09T12:59:02.032Z" }, + { url = "https://files.pythonhosted.org/packages/4b/37/d24c8f8220ff07b839b2c043ea4903a33b0f455abe673ae3c03bbdb7f212/coverage-7.13.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66a80c616f80181f4d643b0f9e709d97bcea413ecd9631e1dedc7401c8e6695d", size = 219381, upload-time = "2026-03-17T10:30:14.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/8b/cd129b0ca4afe886a6ce9d183c44d8301acbd4ef248622e7c49a23145605/coverage-7.13.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:145ede53ccbafb297c1c9287f788d1bc3efd6c900da23bf6931b09eafc931587", size = 219880, upload-time = "2026-03-17T10:30:16.231Z" }, + { url = "https://files.pythonhosted.org/packages/55/2f/e0e5b237bffdb5d6c530ce87cc1d413a5b7d7dfd60fb067ad6d254c35c76/coverage-7.13.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0672854dc733c342fa3e957e0605256d2bf5934feeac328da9e0b5449634a642", size = 250303, upload-time = "2026-03-17T10:30:17.748Z" }, + { url = "https://files.pythonhosted.org/packages/92/be/b1afb692be85b947f3401375851484496134c5554e67e822c35f28bf2fbc/coverage-7.13.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ec10e2a42b41c923c2209b846126c6582db5e43a33157e9870ba9fb70dc7854b", size = 252218, upload-time = "2026-03-17T10:30:19.804Z" }, + { url = "https://files.pythonhosted.org/packages/da/69/2f47bb6fa1b8d1e3e5d0c4be8ccb4313c63d742476a619418f85740d597b/coverage-7.13.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be3d4bbad9d4b037791794ddeedd7d64a56f5933a2c1373e18e9e568b9141686", size = 254326, upload-time = "2026-03-17T10:30:21.321Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d0/79db81da58965bd29dabc8f4ad2a2af70611a57cba9d1ec006f072f30a54/coverage-7.13.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4d2afbc5cc54d286bfb54541aa50b64cdb07a718227168c87b9e2fb8f25e1743", size = 256267, upload-time = "2026-03-17T10:30:23.094Z" }, + { url = "https://files.pythonhosted.org/packages/e5/32/d0d7cc8168f91ddab44c0ce4806b969df5f5fdfdbb568eaca2dbc2a04936/coverage-7.13.5-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3ad050321264c49c2fa67bb599100456fc51d004b82534f379d16445da40fb75", size = 250430, upload-time = "2026-03-17T10:30:25.311Z" }, + { url = "https://files.pythonhosted.org/packages/4d/06/a055311d891ddbe231cd69fdd20ea4be6e3603ffebddf8704b8ca8e10a3c/coverage-7.13.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7300c8a6d13335b29bb76d7651c66af6bd8658517c43499f110ddc6717bfc209", size = 252017, upload-time = "2026-03-17T10:30:27.284Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f6/d0fd2d21e29a657b5f77a2fe7082e1568158340dceb941954f776dce1b7b/coverage-7.13.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:eb07647a5738b89baab047f14edd18ded523de60f3b30e75c2acc826f79c839a", size = 250080, upload-time = "2026-03-17T10:30:29.481Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ab/0d7fb2efc2e9a5eb7ddcc6e722f834a69b454b7e6e5888c3a8567ecffb31/coverage-7.13.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9adb6688e3b53adffefd4a52d72cbd8b02602bfb8f74dcd862337182fd4d1a4e", size = 253843, upload-time = "2026-03-17T10:30:31.301Z" }, + { url = "https://files.pythonhosted.org/packages/ba/6f/7467b917bbf5408610178f62a49c0ed4377bb16c1657f689cc61470da8ce/coverage-7.13.5-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7c8d4bc913dd70b93488d6c496c77f3aff5ea99a07e36a18f865bca55adef8bd", size = 249802, upload-time = "2026-03-17T10:30:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/75/2c/1172fb689df92135f5bfbbd69fc83017a76d24ea2e2f3a1154007e2fb9f8/coverage-7.13.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0e3c426ffc4cd952f54ee9ffbdd10345709ecc78a3ecfd796a57236bfad0b9b8", size = 250707, upload-time = "2026-03-17T10:30:35.2Z" }, + { url = "https://files.pythonhosted.org/packages/67/21/9ac389377380a07884e3b48ba7a620fcd9dbfaf1d40565facdc6b36ec9ef/coverage-7.13.5-cp311-cp311-win32.whl", hash = "sha256:259b69bb83ad9894c4b25be2528139eecba9a82646ebdda2d9db1ba28424a6bf", size = 221880, upload-time = "2026-03-17T10:30:36.775Z" }, + { url = "https://files.pythonhosted.org/packages/af/7f/4cd8a92531253f9d7c1bbecd9fa1b472907fb54446ca768c59b531248dc5/coverage-7.13.5-cp311-cp311-win_amd64.whl", hash = "sha256:258354455f4e86e3e9d0d17571d522e13b4e1e19bf0f8596bcf9476d61e7d8a9", size = 222816, upload-time = "2026-03-17T10:30:38.891Z" }, + { url = "https://files.pythonhosted.org/packages/12/a6/1d3f6155fb0010ca68eba7fe48ca6c9da7385058b77a95848710ecf189b1/coverage-7.13.5-cp311-cp311-win_arm64.whl", hash = "sha256:bff95879c33ec8da99fc9b6fe345ddb5be6414b41d6d1ad1c8f188d26f36e028", size = 221483, upload-time = "2026-03-17T10:30:40.463Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" }, + { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" }, + { url = "https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" }, + { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" }, + { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" }, + { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" }, + { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" }, + { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" }, + { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] [package.optional-dependencies] @@ -1533,7 +1457,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.2" +version = "1.13.3" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1605,6 +1529,7 @@ dependencies = [ { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pypandoc" }, { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -1743,8 +1668,8 @@ requires-dist = [ { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, - { name = "bleach", specifier = "~=6.2.0" }, - { name = "boto3", specifier = "==1.42.68" }, + { name = "bleach", specifier = "~=6.3.0" }, + { name = "boto3", specifier = "==1.42.73" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.6.2" }, @@ -1762,7 +1687,7 @@ requires-dist = [ { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.3.0" }, { name = "google-api-core", specifier = ">=2.19.1" }, - { name = "google-api-python-client", specifier = "==2.192.0" }, + { name = "google-api-python-client", specifier = "==2.193.0" }, { name = "google-auth", specifier = ">=2.47.0" }, { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, @@ -1775,7 +1700,7 @@ requires-dist = [ { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, - { name = "litellm", specifier = "==1.82.2" }, + { name = "litellm", specifier = "==1.82.6" }, { name = "markdown", specifier = "~=3.10.2" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, @@ -1807,18 +1732,19 @@ requires-dist = [ { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, + { name = "pypandoc", specifier = "~=1.13" }, { name = "pypdfium2", specifier = "==5.6.0" }, { name = "python-docx", specifier = "~=1.2.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, { name = "redis", extras = ["hiredis"], specifier = "~=7.3.0" }, - { name = "resend", specifier = "~=2.23.0" }, + { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, - { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.54.0" }, + { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, { name = "sqlalchemy", specifier = "~=2.0.29" }, { name = "sseclient-py", specifier = "~=1.9.0" }, - { name = "starlette", specifier = "==0.52.1" }, + { name = "starlette", specifier = "==1.0.0" }, { name = "tiktoken", specifier = "~=0.12.0" }, { name = "transformers", specifier = "~=5.3.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, @@ -1841,10 +1767,10 @@ dev = [ { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=3.0.0" }, - { name = "pyrefly", specifier = ">=0.55.0" }, + { name = "pyrefly", specifier = ">=0.57.1" }, { name = "pytest", specifier = "~=9.0.2" }, { name = "pytest-benchmark", specifier = "~=5.2.3" }, - { name = "pytest-cov", specifier = "~=7.0.0" }, + { name = "pytest-cov", specifier = "~=7.1.0" }, { name = "pytest-env", specifier = "~=1.6.0" }, { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, @@ -1910,7 +1836,7 @@ tools = [ { name = "nltk", specifier = "~=3.9.1" }, ] vdb = [ - { name = "alibabacloud-gpdb20160503", specifier = "~=3.8.0" }, + { name = "alibabacloud-gpdb20160503", specifier = "~=5.1.0" }, { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, { name = "clickhouse-connect", specifier = "~=0.14.1" }, @@ -2499,7 +2425,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.192.0" +version = "2.193.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2508,9 +2434,9 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/d8/489052a40935e45b9b5b3d6accc14b041360c1507bdc659c2e1a19aaa3ff/google_api_python_client-2.192.0.tar.gz", hash = "sha256:d48cfa6078fadea788425481b007af33fe0ab6537b78f37da914fb6fc112eb27", size = 14209505, upload-time = "2026-03-05T15:17:01.598Z" } +sdist = { url = "https://files.pythonhosted.org/packages/90/f4/e14b6815d3b1885328dd209676a3a4c704882743ac94e18ef0093894f5c8/google_api_python_client-2.193.0.tar.gz", hash = "sha256:8f88d16e89d11341e0a8b199cafde0fb7e6b44260dffb88d451577cbd1bb5d33", size = 14281006, upload-time = "2026-03-17T18:25:29.415Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/76/ec4128f00fefb9011635ae2abc67d7dacd05c8559378f8f05f0c907c38d8/google_api_python_client-2.192.0-py3-none-any.whl", hash = "sha256:63a57d4457cd97df1d63eb89c5fda03c5a50588dcbc32c0115dd1433c08f4b62", size = 14783267, upload-time = "2026-03-05T15:16:58.804Z" }, + { url = "https://files.pythonhosted.org/packages/f0/6d/fe75167797790a56d17799b75e1129bb93f7ff061efc7b36e9731bd4be2b/google_api_python_client-2.193.0-py3-none-any.whl", hash = "sha256:c42aa324b822109901cfecab5dc4fc3915d35a7b376835233c916c70610322db", size = 14856490, upload-time = "2026-03-17T18:25:26.608Z" }, ] [[package]] @@ -2546,7 +2472,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.141.0" +version = "1.142.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2562,9 +2488,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, + { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, ] [[package]] @@ -2617,7 +2543,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.9.0" +version = "3.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2627,9 +2553,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/b1/4f0798e88285b50dfc60ed3a7de071def538b358db2da468c2e0deecbb40/google_cloud_storage-3.9.0.tar.gz", hash = "sha256:f2d8ca7db2f652be757e92573b2196e10fbc09649b5c016f8b422ad593c641cc", size = 17298544, upload-time = "2026-02-02T13:36:34.119Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/0b/816a6ae3c9fd096937d2e5f9670558908811d57d59ddf69dd4b83b326fd1/google_cloud_storage-3.9.0-py3-none-any.whl", hash = "sha256:2dce75a9e8b3387078cbbdad44757d410ecdb916101f8ba308abf202b6968066", size = 321324, upload-time = "2026-02-02T13:36:32.271Z" }, + { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, ] [[package]] @@ -3458,7 +3384,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.17" +version = "0.7.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -3471,9 +3397,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/79/81041dde07a974e728db7def23c1c7255950b8874102925cc77093bc847d/langsmith-0.7.17.tar.gz", hash = "sha256:6c1b0c2863cdd6636d2a58b8d5b1b80060703d98cac2593f4233e09ac25b5a9d", size = 1132228, upload-time = "2026-03-12T20:41:10.808Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/2a/2d5e6c67396fd228670af278c4da7bd6db2b8d11deaf6f108490b6d3f561/langsmith-0.7.22.tar.gz", hash = "sha256:35bfe795d648b069958280760564632fd28ebc9921c04f3e209c0db6a6c7dc04", size = 1134923, upload-time = "2026-03-19T22:45:23.492Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/31/62689d57f4d25792bd6a3c05c868771899481be2f3e31f9e71d31e1ac4ab/langsmith-0.7.17-py3-none-any.whl", hash = "sha256:cbec10460cb6c6ecc94c18c807be88a9984838144ae6c4693c9f859f378d7d02", size = 359147, upload-time = "2026-03-12T20:41:08.758Z" }, + { url = "https://files.pythonhosted.org/packages/1a/94/1f5d72655ab6534129540843776c40eff757387b88e798d8b3bf7e313fd4/langsmith-0.7.22-py3-none-any.whl", hash = "sha256:6e9d5148314d74e86748cb9d3898632cad0320c9323d95f70f969e5bc078eee4", size = 359927, upload-time = "2026-03-19T22:45:21.603Z" }, ] [[package]] @@ -3521,7 +3447,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.82.2" +version = "1.82.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3537,9 +3463,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/12/010a86643f12ac0b004032d5927c260094299a84ed38b5ed20a8f8c7e3c4/litellm-1.82.2.tar.gz", hash = "sha256:f5f4c4049f344a88bf80b2e421bb927807687c99624515d7ff4152d533ec9dcb", size = 17353218, upload-time = "2026-03-13T21:24:24.5Z" } +sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/e4/87e3ca82a8bf6e6bfffb42a539a1350dd6ced1b7169397bd439ba56fde10/litellm-1.82.2-py3-none-any.whl", hash = "sha256:641ed024774fa3d5b4dd9347f0efb1e31fa422fba2a6500aabedee085d1194cb", size = 15524224, upload-time = "2026-03-13T21:24:21.288Z" }, + { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, ] [[package]] @@ -3979,7 +3905,7 @@ wheels = [ [[package]] name = "nltk" -version = "3.9.3" +version = "3.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3987,9 +3913,9 @@ dependencies = [ { name = "regex" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" }, + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, ] [[package]] @@ -4536,7 +4462,7 @@ wheels = [ [[package]] name = "opik" -version = "1.10.39" +version = "1.10.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4555,9 +4481,9 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/0f/b1e00a18cac16b4f36bf6cecc2de962fda810a9416d1159c48f46b81f5ec/opik-1.10.39.tar.gz", hash = "sha256:4d808eb2137070fc5d92a3bed3c3100d9cccfb35f4f0b71ea9990733f293dbb2", size = 780312, upload-time = "2026-03-12T14:08:25.746Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/17/edea6308347cec62e6828de7c573c596559c502b54fa4f0c88a52e2e81f5/opik-1.10.45.tar.gz", hash = "sha256:d8d8627ba03d12def46965e03d58f611daaf5cf878b3d087c53fe1159788c140", size = 789876, upload-time = "2026-03-20T11:35:12.457Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/24/0f4404907a98b4aec4508504570a78a61a3a8b5e451c67326632695ba8e6/opik-1.10.39-py3-none-any.whl", hash = "sha256:a72d735b9afac62e5262294b2f704aca89ec31f5c9beda17504815f7423870c3", size = 1317833, upload-time = "2026-03-12T14:08:23.954Z" }, + { url = "https://files.pythonhosted.org/packages/b7/17/150e9eecfa28cb23f7a0bfe83ae1486a11022b97fe6d12328b455784658d/opik-1.10.45-py3-none-any.whl", hash = "sha256:e8050d9e5e0d92ff587f156eacbdd02099897f39cfe79a98380b6c8ae9906b95", size = 1337714, upload-time = "2026-03-20T11:35:10.237Z" }, ] [[package]] @@ -5273,15 +5199,15 @@ wheels = [ [[package]] name = "pydantic-extra-types" -version = "2.11.0" +version = "2.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/35/2fee58b1316a73e025728583d3b1447218a97e621933fc776fb8c0f2ebdd/pydantic_extra_types-2.11.0.tar.gz", hash = "sha256:4e9991959d045b75feb775683437a97991d02c138e00b59176571db9ce634f0e", size = 157226, upload-time = "2025-12-31T16:18:27.944Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/17/fabd56da47096d240dd45ba627bead0333b0cf0ee8ada9bec579287dadf3/pydantic_extra_types-2.11.0-py3-none-any.whl", hash = "sha256:84b864d250a0fc62535b7ec591e36f2c5b4d1325fa0017eb8cda9aeb63b374a6", size = 74296, upload-time = "2025-12-31T16:18:26.38Z" }, + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] [[package]] @@ -5380,6 +5306,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/7d/037401cecb34728d1c28ea05e196ea3c9d50a1ce0f2172e586e075ff55d8/pyobvector-0.2.25-py3-none-any.whl", hash = "sha256:ae0153f99bd0222783ed7e3951efc31a0d2b462d926b6f86ebd2033409aede8f", size = 64663, upload-time = "2026-03-10T07:18:29.789Z" }, ] +[[package]] +name = "pypandoc" +version = "1.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/d6/410615fc433e5d1eacc00db2044ae2a9c82302df0d35366fe2bd15de024d/pypandoc-1.17.tar.gz", hash = "sha256:51179abfd6e582a25ed03477541b48836b5bba5a4c3b282a547630793934d799", size = 69071, upload-time = "2026-03-14T22:39:07.21Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/86/e2ffa604eacfbec3f430b1d850e7e04c4101eca1a5828f9ae54bf51dfba4/pypandoc-1.17-py3-none-any.whl", hash = "sha256:01fdbffa61edb9f8e82e8faad6954efcb7b6f8f0634aead4d89e322a00225a67", size = 23554, upload-time = "2026-03-14T22:38:46.007Z" }, +] + [[package]] name = "pypandoc-binary" version = "1.17" @@ -5405,11 +5340,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.9.1" +version = "6.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" }, ] [[package]] @@ -5467,18 +5402,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.55.0" +version = "0.57.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, - { url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, - { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, - { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, - { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, - { url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, + { url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" }, + { url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" }, + { url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" }, + { url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" }, + { url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" }, ] [[package]] @@ -5512,16 +5447,16 @@ wheels = [ [[package]] name = "pytest-cov" -version = "7.0.0" +version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, { name = "pluggy" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, ] [[package]] @@ -5917,7 +5852,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -5925,9 +5860,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -5957,15 +5892,15 @@ wheels = [ [[package]] name = "resend" -version = "2.23.0" +version = "2.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/a3/20003e7d14604fef778bd30c69604df3560a657a95a5c29a9688610759b6/resend-2.23.0.tar.gz", hash = "sha256:df613827dcc40eb1c9de2e5ff600cd4081b89b206537dec8067af1a5016d23c7", size = 31416, upload-time = "2026-02-23T19:01:57.603Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/ff/6a4e5e758fc2145c6a7d8563934d8ee24bf96a0212d7ec7d1af1f155bb74/resend-2.26.0.tar.gz", hash = "sha256:957a6a59dc597ce27fbd6d5383220dd9cc497fab99d4f3d775c8a42a449a569e", size = 36238, upload-time = "2026-03-20T22:49:09.728Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/35/64df775b8cd95e89798fd7b1b7fcafa975b6b09f559c10c0650e65b33580/resend-2.23.0-py2.py3-none-any.whl", hash = "sha256:eca6d28a1ffd36c1fc489fa83cb6b511f384792c9f07465f7c92d96c8b4d5636", size = 52599, upload-time = "2026-02-23T19:01:55.962Z" }, + { url = "https://files.pythonhosted.org/packages/16/c2/f88d3299d97aa1d36a923d0846fe185fcf5355ca898c954b2e5a79f090b5/resend-2.26.0-py2.py3-none-any.whl", hash = "sha256:5e25a804a84a68df504f2ade5369ac37e0139e37788a1f20b66c88696595b4bc", size = 57699, upload-time = "2026-03-20T22:49:08.354Z" }, ] [[package]] @@ -6046,27 +5981,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.6" +version = "0.15.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/df/f8629c19c5318601d3121e230f74cbee7a3732339c52b21daa2b82ef9c7d/ruff-0.15.6.tar.gz", hash = "sha256:8394c7bb153a4e3811a4ecdacd4a8e6a4fa8097028119160dffecdcdf9b56ae4", size = 4597916, upload-time = "2026-03-12T23:05:47.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/22/9e4f66ee588588dc6c9af6a994e12d26e19efbe874d1a909d09a6dac7a59/ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac", size = 4601277, upload-time = "2026-03-19T16:26:22.605Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/2f/4e03a7e5ce99b517e98d3b4951f411de2b0fa8348d39cf446671adcce9a2/ruff-0.15.6-py3-none-linux_armv6l.whl", hash = "sha256:7c98c3b16407b2cf3d0f2b80c80187384bc92c6774d85fefa913ecd941256fff", size = 10508953, upload-time = "2026-03-12T23:05:17.246Z" }, - { url = "https://files.pythonhosted.org/packages/70/60/55bcdc3e9f80bcf39edf0cd272da6fa511a3d94d5a0dd9e0adf76ceebdb4/ruff-0.15.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ee7dcfaad8b282a284df4aa6ddc2741b3f4a18b0555d626805555a820ea181c3", size = 10942257, upload-time = "2026-03-12T23:05:23.076Z" }, - { url = "https://files.pythonhosted.org/packages/e7/f9/005c29bd1726c0f492bfa215e95154cf480574140cb5f867c797c18c790b/ruff-0.15.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3bd9967851a25f038fc8b9ae88a7fbd1b609f30349231dffaa37b6804923c4bb", size = 10322683, upload-time = "2026-03-12T23:05:33.738Z" }, - { url = "https://files.pythonhosted.org/packages/5f/74/2f861f5fd7cbb2146bddb5501450300ce41562da36d21868c69b7a828169/ruff-0.15.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13f4594b04e42cd24a41da653886b04d2ff87adbf57497ed4f728b0e8a4866f8", size = 10660986, upload-time = "2026-03-12T23:05:53.245Z" }, - { url = "https://files.pythonhosted.org/packages/c1/a1/309f2364a424eccb763cdafc49df843c282609f47fe53aa83f38272389e0/ruff-0.15.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2ed8aea2f3fe57886d3f00ea5b8aae5bf68d5e195f487f037a955ff9fbaac9e", size = 10332177, upload-time = "2026-03-12T23:05:56.145Z" }, - { url = "https://files.pythonhosted.org/packages/30/41/7ebf1d32658b4bab20f8ac80972fb19cd4e2c6b78552be263a680edc55ac/ruff-0.15.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70789d3e7830b848b548aae96766431c0dc01a6c78c13381f423bf7076c66d15", size = 11170783, upload-time = "2026-03-12T23:06:01.742Z" }, - { url = "https://files.pythonhosted.org/packages/76/be/6d488f6adca047df82cd62c304638bcb00821c36bd4881cfca221561fdfc/ruff-0.15.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:542aaf1de3154cea088ced5a819ce872611256ffe2498e750bbae5247a8114e9", size = 12044201, upload-time = "2026-03-12T23:05:28.697Z" }, - { url = "https://files.pythonhosted.org/packages/71/68/e6f125df4af7e6d0b498f8d373274794bc5156b324e8ab4bf5c1b4fc0ec7/ruff-0.15.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c22e6f02c16cfac3888aa636e9eba857254d15bbacc9906c9689fdecb1953ab", size = 11421561, upload-time = "2026-03-12T23:05:31.236Z" }, - { url = "https://files.pythonhosted.org/packages/f1/9f/f85ef5fd01a52e0b472b26dc1b4bd228b8f6f0435975442ffa4741278703/ruff-0.15.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98893c4c0aadc8e448cfa315bd0cc343a5323d740fe5f28ef8a3f9e21b381f7e", size = 11310928, upload-time = "2026-03-12T23:05:45.288Z" }, - { url = "https://files.pythonhosted.org/packages/8c/26/b75f8c421f5654304b89471ed384ae8c7f42b4dff58fa6ce1626d7f2b59a/ruff-0.15.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:70d263770d234912374493e8cc1e7385c5d49376e41dfa51c5c3453169dc581c", size = 11235186, upload-time = "2026-03-12T23:05:50.677Z" }, - { url = "https://files.pythonhosted.org/packages/fc/d4/d5a6d065962ff7a68a86c9b4f5500f7d101a0792078de636526c0edd40da/ruff-0.15.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:55a1ad63c5a6e54b1f21b7514dfadc0c7fb40093fa22e95143cf3f64ebdcd512", size = 10635231, upload-time = "2026-03-12T23:05:37.044Z" }, - { url = "https://files.pythonhosted.org/packages/d6/56/7c3acf3d50910375349016cf33de24be021532042afbed87942858992491/ruff-0.15.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8dc473ba093c5ec238bb1e7429ee676dca24643c471e11fbaa8a857925b061c0", size = 10340357, upload-time = "2026-03-12T23:06:04.748Z" }, - { url = "https://files.pythonhosted.org/packages/06/54/6faa39e9c1033ff6a3b6e76b5df536931cd30caf64988e112bbf91ef5ce5/ruff-0.15.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:85b042377c2a5561131767974617006f99f7e13c63c111b998f29fc1e58a4cfb", size = 10860583, upload-time = "2026-03-12T23:05:58.978Z" }, - { url = "https://files.pythonhosted.org/packages/cb/1e/509a201b843b4dfb0b32acdedf68d951d3377988cae43949ba4c4133a96a/ruff-0.15.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cef49e30bc5a86a6a92098a7fbf6e467a234d90b63305d6f3ec01225a9d092e0", size = 11410976, upload-time = "2026-03-12T23:05:39.955Z" }, - { url = "https://files.pythonhosted.org/packages/6c/25/3fc9114abf979a41673ce877c08016f8e660ad6cf508c3957f537d2e9fa9/ruff-0.15.6-py3-none-win32.whl", hash = "sha256:bbf67d39832404812a2d23020dda68fee7f18ce15654e96fb1d3ad21a5fe436c", size = 10616872, upload-time = "2026-03-12T23:05:42.451Z" }, - { url = "https://files.pythonhosted.org/packages/89/7a/09ece68445ceac348df06e08bf75db72d0e8427765b96c9c0ffabc1be1d9/ruff-0.15.6-py3-none-win_amd64.whl", hash = "sha256:aee25bc84c2f1007ecb5037dff75cef00414fdf17c23f07dc13e577883dca406", size = 11787271, upload-time = "2026-03-12T23:05:20.168Z" }, - { url = "https://files.pythonhosted.org/packages/7f/d0/578c47dd68152ddddddf31cd7fc67dc30b7cdf639a86275fda821b0d9d98/ruff-0.15.6-py3-none-win_arm64.whl", hash = "sha256:c34de3dd0b0ba203be50ae70f5910b17188556630e2178fd7d79fc030eb0d837", size = 11060497, upload-time = "2026-03-12T23:05:25.968Z" }, + { url = "https://files.pythonhosted.org/packages/41/2f/0b08ced94412af091807b6119ca03755d651d3d93a242682bf020189db94/ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e", size = 10489037, upload-time = "2026-03-19T16:26:32.47Z" }, + { url = "https://files.pythonhosted.org/packages/91/4a/82e0fa632e5c8b1eba5ee86ecd929e8ff327bbdbfb3c6ac5d81631bef605/ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477", size = 10955433, upload-time = "2026-03-19T16:27:00.205Z" }, + { url = "https://files.pythonhosted.org/packages/ab/10/12586735d0ff42526ad78c049bf51d7428618c8b5c467e72508c694119df/ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e", size = 10269302, upload-time = "2026-03-19T16:26:26.183Z" }, + { url = "https://files.pythonhosted.org/packages/eb/5d/32b5c44ccf149a26623671df49cbfbd0a0ae511ff3df9d9d2426966a8d57/ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf", size = 10607625, upload-time = "2026-03-19T16:27:03.263Z" }, + { url = "https://files.pythonhosted.org/packages/5d/f1/f0001cabe86173aaacb6eb9bb734aa0605f9a6aa6fa7d43cb49cbc4af9c9/ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85", size = 10324743, upload-time = "2026-03-19T16:27:09.791Z" }, + { url = "https://files.pythonhosted.org/packages/7a/87/b8a8f3d56b8d848008559e7c9d8bf367934d5367f6d932ba779456e2f73b/ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0", size = 11138536, upload-time = "2026-03-19T16:27:06.101Z" }, + { url = "https://files.pythonhosted.org/packages/e4/f2/4fd0d05aab0c5934b2e1464784f85ba2eab9d54bffc53fb5430d1ed8b829/ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912", size = 11994292, upload-time = "2026-03-19T16:26:48.718Z" }, + { url = "https://files.pythonhosted.org/packages/64/22/fc4483871e767e5e95d1622ad83dad5ebb830f762ed0420fde7dfa9d9b08/ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036", size = 11398981, upload-time = "2026-03-19T16:26:54.513Z" }, + { url = "https://files.pythonhosted.org/packages/b0/99/66f0343176d5eab02c3f7fcd2de7a8e0dd7a41f0d982bee56cd1c24db62b/ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5", size = 11242422, upload-time = "2026-03-19T16:26:29.277Z" }, + { url = "https://files.pythonhosted.org/packages/5d/3a/a7060f145bfdcce4c987ea27788b30c60e2c81d6e9a65157ca8afe646328/ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12", size = 11232158, upload-time = "2026-03-19T16:26:42.321Z" }, + { url = "https://files.pythonhosted.org/packages/a7/53/90fbb9e08b29c048c403558d3cdd0adf2668b02ce9d50602452e187cd4af/ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c", size = 10577861, upload-time = "2026-03-19T16:26:57.459Z" }, + { url = "https://files.pythonhosted.org/packages/2f/aa/5f486226538fe4d0f0439e2da1716e1acf895e2a232b26f2459c55f8ddad/ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4", size = 10327310, upload-time = "2026-03-19T16:26:35.909Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/271afdffb81fe7bfc8c43ba079e9d96238f674380099457a74ccb3863857/ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d", size = 10840752, upload-time = "2026-03-19T16:26:45.723Z" }, + { url = "https://files.pythonhosted.org/packages/bf/29/a4ae78394f76c7759953c47884eb44de271b03a66634148d9f7d11e721bd/ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580", size = 11336961, upload-time = "2026-03-19T16:26:39.076Z" }, + { url = "https://files.pythonhosted.org/packages/26/6b/8786ba5736562220d588a2f6653e6c17e90c59ced34a2d7b512ef8956103/ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de", size = 10582538, upload-time = "2026-03-19T16:26:15.992Z" }, + { url = "https://files.pythonhosted.org/packages/2b/e9/346d4d3fffc6871125e877dae8d9a1966b254fbd92a50f8561078b88b099/ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1", size = 11755839, upload-time = "2026-03-19T16:26:19.897Z" }, + { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] [[package]] @@ -6105,14 +6040,14 @@ wheels = [ [[package]] name = "scipy-stubs" -version = "1.17.1.2" +version = "1.17.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "optype", extra = ["numpy"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c7/ab/43f681ffba42f363b7ed6b767fd215d1e26006578214ff8330586a11bf95/scipy_stubs-1.17.1.2.tar.gz", hash = "sha256:2ecadc8c87a3b61aaf7379d6d6b10f1038a829c53b9efe5b174fb97fc8b52237", size = 388354, upload-time = "2026-03-15T22:33:20.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/59/59c6cc3f9970154b9ed6b1aff42a0185cdd60cef54adc0404b9e77972221/scipy_stubs-1.17.1.3.tar.gz", hash = "sha256:5eb87a8d23d726706259b012ebe76a4a96a9ae9e141fc59bf55fc8eac2ed9e0f", size = 392185, upload-time = "2026-03-22T22:11:58.34Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/0b/ec4fe720c1202d9df729a3e9d9b7e4d2da9f6e7f28bd2877b7d0769f4f75/scipy_stubs-1.17.1.2-py3-none-any.whl", hash = "sha256:f19e8f5273dbe3b7ee6a9554678c3973b9695fa66b91f29206d00830a1536c06", size = 594377, upload-time = "2026-03-15T22:33:18.684Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d4/94304532c0a75a55526119043dd44a9bd1541a21e14483cbb54261c527d2/scipy_stubs-1.17.1.3-py3-none-any.whl", hash = "sha256:7b91d3f05aa47da06fbca14eb6c5bb4c28994e9245fd250cc847e375bab31297", size = 597933, upload-time = "2026-03-22T22:11:56.525Z" }, ] [[package]] @@ -6131,15 +6066,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.54.0" +version = "2.55.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c8/e9/2e3a46c304e7fa21eaa70612f60354e32699c7102eb961f67448e222ad7c/sentry_sdk-2.54.0.tar.gz", hash = "sha256:2620c2575128d009b11b20f7feb81e4e4e8ae08ec1d36cbc845705060b45cc1b", size = 413813, upload-time = "2026-03-02T15:12:41.355Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/b8/285293dc60fc198fffc3fcdbc7c6d4e646e0f74e61461c355d40faa64ceb/sentry_sdk-2.55.0.tar.gz", hash = "sha256:3774c4d8820720ca4101548131b9c162f4c9426eb7f4d24aca453012a7470f69", size = 424505, upload-time = "2026-03-17T14:15:51.707Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/53/39/be412cc86bc6247b8f69e9383d7950711bd86f8d0a4a4b0fe8fad685bc21/sentry_sdk-2.54.0-py2.py3-none-any.whl", hash = "sha256:fd74e0e281dcda63afff095d23ebcd6e97006102cdc8e78a29f19ecdf796a0de", size = 439198, upload-time = "2026-03-02T15:12:39.546Z" }, + { url = "https://files.pythonhosted.org/packages/9a/66/20465097782d7e1e742d846407ea7262d338c6e876ddddad38ca8907b38f/sentry_sdk-2.55.0-py2.py3-none-any.whl", hash = "sha256:97026981cb15699394474a196b88503a393cbc58d182ece0d3abe12b9bd978d4", size = 449284, upload-time = "2026-03-17T14:15:49.604Z" }, ] [package.optional-dependencies] @@ -6375,15 +6310,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.52.1" +version = "1.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/69/17425771797c36cded50b7fe44e850315d039f28b15901ab44839e70b593/starlette-1.0.0.tar.gz", hash = "sha256:6a4beaf1f81bb472fd19ea9b918b50dc3a77a6f2e190a12954b25e6ed5eea149", size = 2655289, upload-time = "2026-03-22T18:29:46.779Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, + { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] [[package]] @@ -6792,11 +6727,11 @@ wheels = [ [[package]] name = "types-cachetools" -version = "6.2.0.20251022" +version = "6.2.0.20260317" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/7f/16a4d8344c28193a5a74358028c2d2f753f0d9658dd98b9e1967c50045a2/types_cachetools-6.2.0.20260317.tar.gz", hash = "sha256:6d91855bcc944665897c125e720aa3c80aace929b77a64e796343701df4f61c6", size = 9812, upload-time = "2026-03-17T04:06:32.007Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" }, + { url = "https://files.pythonhosted.org/packages/17/9a/b00b23054934c4d569c19f7278c4fb32746cd36a64a175a216d3073a4713/types_cachetools-6.2.0.20260317-py3-none-any.whl", hash = "sha256:92fa9bc50e4629e31fca67ceb3fb1de71791e314fa16c0a0d2728724dc222c8b", size = 9346, upload-time = "2026-03-17T04:06:31.184Z" }, ] [[package]] @@ -6840,11 +6775,11 @@ wheels = [ [[package]] name = "types-docutils" -version = "0.22.3.20260316" +version = "0.22.3.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9f/27/a7f16b3a2fad0a4ddd85a668319f9a1d0311c4bd9578894f6471c7e6c788/types_docutils-0.22.3.20260316.tar.gz", hash = "sha256:8ef27d565b9831ff094fe2eac75337a74151013e2d21ecabd445c2955f891564", size = 57263, upload-time = "2026-03-16T04:29:12.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bb/243a87fc1605a4a94c2c343d6dbddbf0d7ef7c0b9550f360b8cda8e82c39/types_docutils-0.22.3.20260322.tar.gz", hash = "sha256:e2450bb997283c3141ec5db3e436b91f0aa26efe35eb9165178ca976ccb4930b", size = 57311, upload-time = "2026-03-22T04:08:44.064Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/60/c1f22b7cfc4837d5419e5a2d8702c7d65f03343f866364b71cccd8a73b79/types_docutils-0.22.3.20260316-py3-none-any.whl", hash = "sha256:083c7091b8072c242998ec51da1bf1492f0332387da81c3b085efbf5ca754c7d", size = 91968, upload-time = "2026-03-16T04:29:11.114Z" }, + { url = "https://files.pythonhosted.org/packages/c6/4a/22c090cd4615a16917dff817cbe7c5956da376c961e024c241cd962d2c3d/types_docutils-0.22.3.20260322-py3-none-any.whl", hash = "sha256:681d4510ce9b80a0c6a593f0f9843d81f8caa786db7b39ba04d9fd5480ac4442", size = 91978, upload-time = "2026-03-22T04:08:43.117Z" }, ] [[package]] @@ -6874,15 +6809,15 @@ wheels = [ [[package]] name = "types-gevent" -version = "25.9.0.20251228" +version = "25.9.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/85/c5043c4472f82c8ee3d9e0673eb4093c7d16770a26541a137a53a1d096f6/types_gevent-25.9.0.20251228.tar.gz", hash = "sha256:423ef9891d25c5a3af236c3e9aace4c444c86ff773fe13ef22731bc61d59abef", size = 38063, upload-time = "2025-12-28T03:28:28.651Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f0/14a99ddcaa69b559fa7cec8c9de880b792bebb0b848ae865d94ea9058533/types_gevent-25.9.0.20260322.tar.gz", hash = "sha256:91257920845762f09753c08aa20fad1743ac13d2de8bcf23f4b8fe967d803732", size = 38241, upload-time = "2026-03-22T04:08:55.213Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/b7/a2d6b652ab5a26318b68cafd58c46fafb9b15c5313d2d76a70b838febb4b/types_gevent-25.9.0.20251228-py3-none-any.whl", hash = "sha256:e2e225af4fface9241c16044983eb2fc3993f2d13d801f55c2932848649b7f2f", size = 55486, upload-time = "2025-12-28T03:28:27.382Z" }, + { url = "https://files.pythonhosted.org/packages/89/0f/964440b57eb4ddb4aca03479a4093852e1ce79010d1c5967234e6f5d6bd9/types_gevent-25.9.0.20260322-py3-none-any.whl", hash = "sha256:21b3c269b3a20ecb0e4668289c63b97d21694d84a004ab059c1e32ab970eacc2", size = 55500, upload-time = "2026-03-22T04:08:54.103Z" }, ] [[package]] @@ -6965,11 +6900,11 @@ wheels = [ [[package]] name = "types-openpyxl" -version = "3.1.5.20260316" +version = "3.1.5.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/38/32f8ee633dd66ca6d52b8853b9fd45dc3869490195a6ed435d5c868b9c2d/types_openpyxl-3.1.5.20260316.tar.gz", hash = "sha256:081dda9427ea1141e5649e3dcf630e7013a4cf254a5862a7e0a3f53c123b7ceb", size = 101318, upload-time = "2026-03-16T04:29:05.004Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/bf/15240de4d68192d2a1f385ef2f6f1ecb29b85d2f3791dd2e2d5b980be30f/types_openpyxl-3.1.5.20260322.tar.gz", hash = "sha256:a61d66ebe1e49697853c6db8e0929e1cda2c96755e71fb676ed7fc48dfdcf697", size = 101325, upload-time = "2026-03-22T04:08:40.426Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/df/b87ae6226ed7cc84b9e43119c489c7f053a9a25e209e0ebb5d84bc36fa37/types_openpyxl-3.1.5.20260316-py3-none-any.whl", hash = "sha256:38e7e125df520fb7eb72cb1129c9f024eb99ef9564aad2c27f68f080c26bcf2d", size = 166084, upload-time = "2026-03-16T04:29:03.657Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b4/c14191b30bcb266365b124b2bb4e67ecd68425a78ba77ee026f33667daa9/types_openpyxl-3.1.5.20260322-py3-none-any.whl", hash = "sha256:2f515f0b0bbfb04bfb587de34f7522d90b5151a8da7bbbd11ecec4ca40f64238", size = 166102, upload-time = "2026-03-22T04:08:39.174Z" }, ] [[package]] @@ -7044,11 +6979,11 @@ wheels = [ [[package]] name = "types-python-dateutil" -version = "2.9.0.20260305" +version = "2.9.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/c7/025c624f347e10476b439a6619a95f1d200250ea88e7ccea6e09e48a7544/types_python_dateutil-2.9.0.20260305.tar.gz", hash = "sha256:389717c9f64d8f769f36d55a01873915b37e97e52ce21928198d210fbd393c8b", size = 16885, upload-time = "2026-03-05T04:00:47.409Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/02/f72df9ef5ffc4f959b83cb80c8aa03eb8718a43e563ecd99ccffe265fa89/types_python_dateutil-2.9.0.20260323.tar.gz", hash = "sha256:a107aef5841db41ace381dbbbd7e4945220fc940f7a72172a0be5a92d9ab7164", size = 16897, upload-time = "2026-03-23T04:15:14.829Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/77/8c0d1ec97f0d9707ad3d8fa270ab8964e7b31b076d2f641c94987395cc75/types_python_dateutil-2.9.0.20260305-py3-none-any.whl", hash = "sha256:a3be9ca444d38cadabd756cfbb29780d8b338ae2a3020e73c266a83cc3025dd7", size = 18419, upload-time = "2026-03-05T04:00:46.392Z" }, + { url = "https://files.pythonhosted.org/packages/92/c1/b661838b97453e699a215451f2e22cee750eaaf4ea4619b34bdaf01221a4/types_python_dateutil-2.9.0.20260323-py3-none-any.whl", hash = "sha256:a23a50a07f6eb87e729d4cb0c2eb511c81761eeb3f505db2c1413be94aae8335", size = 18433, upload-time = "2026-03-23T04:15:13.683Z" }, ] [[package]] @@ -7062,11 +6997,11 @@ wheels = [ [[package]] name = "types-pywin32" -version = "311.0.0.20260316" +version = "311.0.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/17/a8/b4652002a854fcfe5d272872a0ae2d5df0e9dc482e1a6dfb5e97b905b76f/types_pywin32-311.0.0.20260316.tar.gz", hash = "sha256:c136fa489fe6279a13bca167b750414e18d657169b7cf398025856dc363004e8", size = 329956, upload-time = "2026-03-16T04:28:57.366Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/cc/f03ddb7412ac2fc2238358b617c2d5919ba96812dff8d3081f3b2754bb83/types_pywin32-311.0.0.20260323.tar.gz", hash = "sha256:2e8dc6a59fedccbc51b241651ce1e8aa58488934f517debf23a9c6d0ff329b4b", size = 332263, upload-time = "2026-03-23T04:15:20.004Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/83/704698d93788cf1c2f5e236eae2b37f1b2152ef84dc66b4b83f6c7487b76/types_pywin32-311.0.0.20260316-py3-none-any.whl", hash = "sha256:abb643d50012386d697af49384cc0e6e475eab76b0ca2a7f93d480d0862b3692", size = 392959, upload-time = "2026-03-16T04:28:56.104Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/d786d5d8b846e3cbe1ee52da8945560b111c789b42c3771b2129b312ab94/types_pywin32-311.0.0.20260323-py3-none-any.whl", hash = "sha256:2f2b03fc72ae77ccbb0ee258da0f181c3a38bd8602f6e332e42587b3b0d5f095", size = 395435, upload-time = "2026-03-23T04:15:18.76Z" }, ] [[package]] @@ -7162,16 +7097,16 @@ wheels = [ [[package]] name = "types-tensorflow" -version = "2.18.0.20260224" +version = "2.18.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "types-protobuf" }, { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/cb/4914c2fbc1cf8a8d1ef2a7c727bb6f694879be85edeee880a0c88e696af8/types_tensorflow-2.18.0.20260224.tar.gz", hash = "sha256:9b0ccc91c79c88791e43d3f80d6c879748fa0361409c5ff23c7ffe3709be00f2", size = 258786, upload-time = "2026-02-24T04:06:45.613Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/81dfaa2680031a6e087bcdfaf1c0556371098e229aee541e21c81a381065/types_tensorflow-2.18.0.20260322.tar.gz", hash = "sha256:135dc6ca06cc647a002e1bca5c5c99516fde51efd08e46c48a9b1916fc5df07f", size = 259030, upload-time = "2026-03-22T04:09:14.069Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/1d/a1c3c60f0eb1a204500dbdc66e3d18aafabc86ad07a8eca71ea05bc8c5a8/types_tensorflow-2.18.0.20260224-py3-none-any.whl", hash = "sha256:6a25f5f41f3e06f28c1f65c6e09f484d4ba0031d6d8df83a39df9d890245eefc", size = 329746, upload-time = "2026-02-24T04:06:44.4Z" }, + { url = "https://files.pythonhosted.org/packages/5b/0c/a178061450b640e53577e2c423ad22bf5d3f692f6bfeeb12156d02b531ef/types_tensorflow-2.18.0.20260322-py3-none-any.whl", hash = "sha256:d8776b6daacdb279e64f105f9dcbc0b8e3544b9a2f2eb71ec6ea5955081f65e6", size = 329771, upload-time = "2026-03-22T04:09:12.844Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index 9d6cd65318..8cf77cf56b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -771,6 +771,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # VikingDB configurations, only available when VECTOR_STORE is `vikingdb` VIKINGDB_ACCESS_KEY=your-ak diff --git a/docker/dify-env-sync.py b/docker/dify-env-sync.py new file mode 100755 index 0000000000..d7c762748c --- /dev/null +++ b/docker/dify-env-sync.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 + +# ================================================================ +# Dify Environment Variables Synchronization Script +# +# Features: +# - Synchronize latest settings from .env.example to .env +# - Preserve custom settings in existing .env +# - Add new environment variables +# - Detect removed environment variables +# - Create backup files +# ================================================================ + +import argparse +import re +import shutil +import sys +from datetime import datetime +from pathlib import Path + +# ANSI color codes +RED = "\033[0;31m" +GREEN = "\033[0;32m" +YELLOW = "\033[1;33m" +BLUE = "\033[0;34m" +NC = "\033[0m" # No Color + + +def supports_color() -> bool: + """Return True if the terminal supports ANSI color codes.""" + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + + +def log_info(message: str) -> None: + """Print an informational message in blue.""" + if supports_color(): + print(f"{BLUE}[INFO]{NC} {message}") + else: + print(f"[INFO] {message}") + + +def log_success(message: str) -> None: + """Print a success message in green.""" + if supports_color(): + print(f"{GREEN}[SUCCESS]{NC} {message}") + else: + print(f"[SUCCESS] {message}") + + +def log_warning(message: str) -> None: + """Print a warning message in yellow to stderr.""" + if supports_color(): + print(f"{YELLOW}[WARNING]{NC} {message}", file=sys.stderr) + else: + print(f"[WARNING] {message}", file=sys.stderr) + + +def log_error(message: str) -> None: + """Print an error message in red to stderr.""" + if supports_color(): + print(f"{RED}[ERROR]{NC} {message}", file=sys.stderr) + else: + print(f"[ERROR] {message}", file=sys.stderr) + + +def parse_env_file(path: Path) -> dict[str, str]: + """Parse an .env-style file and return a mapping of key to raw value. + + Lines that are blank or start with '#' (after optional whitespace) are + skipped. Only lines containing '=' are considered variable definitions. + + Args: + path: Path to the .env file to parse. + + Returns: + Ordered dict mapping variable name to its value string. + """ + variables: dict[str, str] = {} + with path.open(encoding="utf-8") as fh: + for line in fh: + line = line.rstrip("\n") + # Skip blank lines and comment lines + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + if key: + variables[key] = value.strip() + return variables + + +def check_files(work_dir: Path) -> None: + """Verify required files exist; create .env from .env.example if absent. + + Args: + work_dir: Directory that must contain .env.example (and optionally .env). + + Raises: + SystemExit: If .env.example does not exist. + """ + log_info("Checking required files...") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + if not example_file.exists(): + log_error(".env.example file not found") + sys.exit(1) + + if not env_file.exists(): + log_warning(".env file does not exist. Creating from .env.example.") + shutil.copy2(example_file, env_file) + log_success(".env file created") + + log_success("Required files verified") + + +def create_backup(work_dir: Path) -> None: + """Create a timestamped backup of the current .env file. + + Backups are placed in ``/env-backup/`` with the filename + ``.env.backup_``. + + Args: + work_dir: Directory containing the .env file to back up. + """ + env_file = work_dir / ".env" + if not env_file.exists(): + return + + backup_dir = work_dir / "env-backup" + if not backup_dir.exists(): + backup_dir.mkdir(parents=True) + log_info(f"Created backup directory: {backup_dir}") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = backup_dir / f".env.backup_{timestamp}" + shutil.copy2(env_file, backup_file) + log_success(f"Backed up existing .env to {backup_file}") + + +def analyze_value_change(current: str, recommended: str) -> str | None: + """Analyse what kind of change occurred between two env values. + + Args: + current: Value currently set in .env. + recommended: Value present in .env.example. + + Returns: + A human-readable description string, or None when no analysis applies. + """ + use_colors = supports_color() + + def colorize(color: str, text: str) -> str: + return f"{color}{text}{NC}" if use_colors else text + + if not current and recommended: + return colorize(RED, " -> Setting from empty to recommended value") + if current and not recommended: + return colorize(RED, " -> Recommended value changed to empty") + + # Numeric comparison + if re.fullmatch(r"\d+", current) and re.fullmatch(r"\d+", recommended): + cur_int, rec_int = int(current), int(recommended) + if cur_int < rec_int: + return colorize(BLUE, f" -> Numeric increase ({current} < {recommended})") + if cur_int > rec_int: + return colorize(YELLOW, f" -> Numeric decrease ({current} > {recommended})") + return None + + # Boolean comparison + if current.lower() in {"true", "false"} and recommended.lower() in {"true", "false"}: + if current.lower() != recommended.lower(): + return colorize(BLUE, f" -> Boolean value change ({current} -> {recommended})") + return None + + # URL / endpoint + if current.startswith(("http://", "https://")) or recommended.startswith(("http://", "https://")): + return colorize(BLUE, " -> URL/endpoint change") + + # File path + if current.startswith("/") or recommended.startswith("/"): + return colorize(BLUE, " -> File path change") + + # String length + if len(current) != len(recommended): + return colorize(YELLOW, f" -> String length change ({len(current)} -> {len(recommended)} characters)") + + return None + + +def detect_differences(env_vars: dict[str, str], example_vars: dict[str, str]) -> dict[str, tuple[str, str]]: + """Find variables whose values differ between .env and .env.example. + + Only variables present in *both* files are compared; new or removed + variables are handled by separate functions. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Mapping of key -> (env_value, example_value) for every key whose + values differ. + """ + log_info("Detecting differences between .env and .env.example...") + + diffs: dict[str, tuple[str, str]] = {} + for key, example_value in example_vars.items(): + if key in env_vars and env_vars[key] != example_value: + diffs[key] = (env_vars[key], example_value) + + if diffs: + log_success(f"Detected differences in {len(diffs)} environment variables") + show_differences_detail(diffs) + else: + log_info("No differences detected") + + return diffs + + +def show_differences_detail(diffs: dict[str, tuple[str, str]]) -> None: + """Print a formatted table of differing environment variables. + + Args: + diffs: Mapping of key -> (current_value, recommended_value). + """ + use_colors = supports_color() + + log_info("") + log_info("=== Environment Variable Differences ===") + + if not diffs: + log_info("No differences to display") + return + + for count, (key, (env_value, example_value)) in enumerate(diffs.items(), start=1): + print() + if use_colors: + print(f"{YELLOW}[{count}] {key}{NC}") + print(f" {GREEN}.env (current){NC} : {env_value}") + print(f" {BLUE}.env.example (recommended){NC} : {example_value}") + else: + print(f"[{count}] {key}") + print(f" .env (current) : {env_value}") + print(f" .env.example (recommended) : {example_value}") + + analysis = analyze_value_change(env_value, example_value) + if analysis: + print(analysis) + + print() + log_info("=== Difference Analysis Complete ===") + log_info("Note: Consider changing to the recommended values above.") + log_info("Current implementation preserves .env values.") + print() + + +def detect_removed_variables(env_vars: dict[str, str], example_vars: dict[str, str]) -> list[str]: + """Identify variables present in .env but absent from .env.example. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Sorted list of variable names that no longer appear in .env.example. + """ + log_info("Detecting removed environment variables...") + + removed = sorted(set(env_vars) - set(example_vars)) + + if removed: + log_warning("The following environment variables have been removed from .env.example:") + for var in removed: + log_warning(f" - {var}") + log_warning("Consider manually removing these variables from .env") + else: + log_success("No removed environment variables found") + + return removed + + +def sync_env_file(work_dir: Path, env_vars: dict[str, str], diffs: dict[str, tuple[str, str]]) -> None: + """Rewrite .env based on .env.example while preserving custom values. + + The output file follows the exact line structure of .env.example + (preserving comments, blank lines, and ordering). For every variable + that exists in .env with a different value from the example, the + current .env value is kept. Variables that are new in .env.example + (not present in .env at all) are added with the example's default. + + Args: + work_dir: Directory containing .env and .env.example. + env_vars: Parsed key/value pairs from the original .env. + diffs: Keys whose .env values differ from .env.example (to preserve). + """ + log_info("Starting partial synchronization of .env file...") + + example_file = work_dir / ".env.example" + new_env_file = work_dir / ".env.new" + + # Keys whose current .env value should override the example default + preserved_keys: set[str] = set(diffs.keys()) + + preserved_count = 0 + updated_count = 0 + + env_var_pattern = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=") + + with example_file.open(encoding="utf-8") as src, new_env_file.open("w", encoding="utf-8") as dst: + for line in src: + raw_line = line.rstrip("\n") + match = env_var_pattern.match(raw_line) + if match: + key = match.group(1) + if key in preserved_keys: + # Write the preserved value from .env + dst.write(f"{key}={env_vars[key]}\n") + log_info(f" Preserved: {key} (.env value)") + preserved_count += 1 + else: + # Use the example value (covers new vars and unchanged ones) + dst.write(line if line.endswith("\n") else raw_line + "\n") + updated_count += 1 + else: + # Blank line, comment, or non-variable line — keep as-is + dst.write(line if line.endswith("\n") else raw_line + "\n") + + # Atomically replace the original .env + try: + new_env_file.replace(work_dir / ".env") + except OSError as exc: + log_error(f"Failed to replace .env file: {exc}") + new_env_file.unlink(missing_ok=True) + sys.exit(1) + + log_success("Successfully created new .env file") + log_success("Partial synchronization of .env file completed") + log_info(f" Preserved .env values: {preserved_count}") + log_info(f" Updated to .env.example values: {updated_count}") + + +def show_statistics(work_dir: Path) -> None: + """Print a summary of variable counts from both env files. + + Args: + work_dir: Directory containing .env and .env.example. + """ + log_info("Synchronization statistics:") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + example_count = len(parse_env_file(example_file)) if example_file.exists() else 0 + env_count = len(parse_env_file(env_file)) if env_file.exists() else 0 + + log_info(f" .env.example environment variables: {example_count}") + log_info(f" .env environment variables: {env_count}") + + +def build_arg_parser() -> argparse.ArgumentParser: + """Build and return the CLI argument parser. + + Returns: + Configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser( + prog="dify-env-sync", + description=( + "Synchronize .env with .env.example: add new variables, " + "preserve custom values, and report removed variables." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Examples:\n" + " # Run from the docker/ directory (default)\n" + " python dify-env-sync.py\n\n" + " # Specify a custom working directory\n" + " python dify-env-sync.py --dir /path/to/docker\n" + ), + ) + parser.add_argument( + "--dir", + metavar="DIRECTORY", + default=".", + help="Working directory containing .env and .env.example (default: current directory)", + ) + parser.add_argument( + "--no-backup", + action="store_true", + default=False, + help="Skip creating a timestamped backup of the existing .env file", + ) + return parser + + +def main() -> None: + """Orchestrate the complete environment variable synchronization process.""" + parser = build_arg_parser() + args = parser.parse_args() + + work_dir = Path(args.dir).resolve() + + log_info("=== Dify Environment Variables Synchronization Script ===") + log_info(f"Execution started: {datetime.now()}") + log_info(f"Working directory: {work_dir}") + + # 1. Verify prerequisites + check_files(work_dir) + + # 2. Backup existing .env + if not args.no_backup: + create_backup(work_dir) + + # 3. Parse both files + env_vars = parse_env_file(work_dir / ".env") + example_vars = parse_env_file(work_dir / ".env.example") + + # 4. Report differences (values that changed in the example) + diffs = detect_differences(env_vars, example_vars) + + # 5. Report variables removed from the example + detect_removed_variables(env_vars, example_vars) + + # 6. Rewrite .env + sync_env_file(work_dir, env_vars, diffs) + + # 7. Print summary statistics + show_statistics(work_dir) + + log_success("=== Synchronization process completed successfully ===") + log_info(f"Execution finished: {datetime.now()}") + + +if __name__ == "__main__": + main() diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 04bd2858ff..98c2613a07 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -245,7 +245,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -269,7 +269,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 2dca581903..1746bb567a 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -97,7 +97,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always env_file: - ./middleware.env @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index bf72a0f623..2a75de1a89 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -345,6 +345,9 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER} BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT:-500} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO:-0.05} + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: ${BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS:-300} VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} @@ -728,7 +731,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -770,7 +773,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -809,7 +812,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -839,7 +842,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -952,7 +955,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -976,7 +979,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 256e669c8d..fbe9ebc448 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -28,6 +28,7 @@ http_access deny manager http_access allow localhost include /etc/squid/conf.d/*.conf http_access deny all +tcp_outgoing_address 0.0.0.0 ################################## Proxy Server ################################ http_port ${HTTP_PORT} diff --git a/docker/volumes/sandbox/conf/config.yaml b/docker/volumes/sandbox/conf/config.yaml index 8c1a1deb54..3b4a6b8439 100644 --- a/docker/volumes/sandbox/conf/config.yaml +++ b/docker/volumes/sandbox/conf/config.yaml @@ -5,7 +5,8 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 +nodejs_path: /usr/local/bin/node enable_network: True # please make sure there is no network risk in your environment allowed_syscalls: # please leave it empty if you have no idea how seccomp works proxy: diff --git a/docker/volumes/sandbox/conf/config.yaml.example b/docker/volumes/sandbox/conf/config.yaml.example index f92c19e51a..365089cb9e 100644 --- a/docker/volumes/sandbox/conf/config.yaml.example +++ b/docker/volumes/sandbox/conf/config.yaml.example @@ -5,7 +5,7 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 python_lib_path: - /usr/local/lib/python3.10 - /usr/lib/python3.10 diff --git a/docs/eu-ai-act-compliance.md b/docs/eu-ai-act-compliance.md new file mode 100644 index 0000000000..5fa29eed3f --- /dev/null +++ b/docs/eu-ai-act-compliance.md @@ -0,0 +1,186 @@ +# EU AI Act Compliance Guide for Dify Deployers + +Dify is an LLMOps platform for building RAG pipelines, agents, and AI workflows. If you deploy Dify in the EU — whether self-hosted or using a cloud provider — the EU AI Act applies to your deployment. This guide covers what the regulation requires and how Dify's architecture maps to those requirements. + +## Is your system in scope? + +The detailed obligations in Articles 12, 13, and 14 only apply to **high-risk AI systems** as defined in Annex III of the EU AI Act. A Dify application is high-risk if it is used for: + +- **Recruitment and HR** — screening candidates, evaluating employee performance, allocating tasks +- **Credit scoring and insurance** — assessing creditworthiness or setting premiums +- **Law enforcement** — profiling, criminal risk assessment, border control +- **Critical infrastructure** — managing energy, water, transport, or telecommunications systems +- **Education assessment** — grading students, determining admissions +- **Essential public services** — evaluating eligibility for benefits, housing, or emergency services + +Most Dify deployments (customer-facing chatbots, internal knowledge bases, content generation workflows) are **not** high-risk. If your Dify application does not fall into one of the categories above: + +- **Article 50** (end-user transparency) still applies if users interact with your application directly. See the [Article 50 section](#article-50-end-user-transparency) below. +- **GDPR** still applies if you process personal data. See the [GDPR section](#gdpr-considerations) below. +- The high-risk obligations (Articles 9-15) are less likely to apply, but risk classification is context-dependent. **Do not self-classify without legal review.** Focus on Article 50 (transparency) and GDPR (data protection) as your baseline obligations. + +If you are unsure whether your use case qualifies as high-risk, consult a qualified legal professional before proceeding. + +## Self-hosted vs cloud: different compliance profiles + +| Deployment | Your role | Dify's role | Who handles compliance? | +|-----------|----------|-------------|------------------------| +| **Self-hosted** | Provider and deployer | Framework provider — obligations under Article 25 apply only if Dify is placed on the market or put into service as part of a complete AI system bearing its name or trademark | You | +| **Dify Cloud** | Deployer | Provider and processor | Shared — Dify handles SOC 2 and GDPR for the platform; you handle AI Act obligations for your specific use case | + +Dify Cloud already has SOC 2 Type II and GDPR compliance for the platform itself. But the EU AI Act adds obligations specific to AI systems that SOC 2 does not cover: risk classification, technical documentation, transparency, and human oversight. + +## Supported providers and services + +Dify integrates with a broad range of AI providers and data stores. The following are the key ones relevant to compliance: + +- **AI providers:** HuggingFace (core), plus integrations with OpenAI, Anthropic, Google, and 100+ models via provider plugins +- **Model identifiers include:** gpt-4o, gpt-3.5-turbo, claude-3-opus, gemini-2.5-flash, whisper-1, and others +- **Vector database connections:** Extensive RAG infrastructure supporting numerous vector stores + +Dify's plugin architecture means actual provider usage depends on your configuration. Document which providers and models are active in your deployment. + +## Data flow diagram + +A typical Dify RAG deployment: + +```mermaid +graph LR + USER((User)) -->|query| DIFY[Dify Platform] + DIFY -->|prompts| LLM([LLM Provider]) + LLM -->|responses| DIFY + DIFY -->|documents| EMBED([Embedding Model]) + EMBED -->|vectors| DIFY + DIFY -->|store/retrieve| VS[(Vector Store)] + DIFY -->|knowledge| KB[(Knowledge Base)] + DIFY -->|response| USER + + classDef processor fill:#60a5fa,stroke:#1e40af,color:#000 + classDef controller fill:#4ade80,stroke:#166534,color:#000 + classDef app fill:#a78bfa,stroke:#5b21b6,color:#000 + classDef user fill:#f472b6,stroke:#be185d,color:#000 + + class USER user + class DIFY app + class LLM processor + class EMBED processor + class VS controller + class KB controller +``` + +**GDPR roles** (providers are typically processors for customer-submitted data, but the exact role depends on each provider's terms of service and processing purpose; deployers should review each provider's DPA): +- **Cloud LLM providers (OpenAI, Anthropic, Google)** typically act as processors — requires DPA. +- **Cloud embedding services** typically act as processors — requires DPA. +- **Self-hosted vector stores (Weaviate, Qdrant, pgvector):** Your organization remains the controller — no third-party transfer. +- **Cloud vector stores (Pinecone, Zilliz Cloud)** typically act as processors — requires DPA. +- **Knowledge base documents:** Your organization is the controller — stored in your infrastructure. + +## Article 11: Technical documentation + +High-risk systems need Annex IV documentation. For Dify deployments, key sections include: + +| Section | What Dify provides | What you must document | +|---------|-------------------|----------------------| +| General description | Platform capabilities, supported models | Your specific use case, intended users, deployment context | +| Development process | Dify's architecture, plugin system | Your RAG pipeline design, prompt engineering, knowledge base curation | +| Monitoring | Dify's built-in logging and analytics | Your monitoring plan, alert thresholds, incident response | +| Performance metrics | Dify's evaluation features | Your accuracy benchmarks, quality thresholds, bias testing | +| Risk management | — | Risk assessment for your specific use case | + +Some sections can be derived from Dify's architecture and your deployment configuration, as shown in the table above. The remaining sections require your input. + +## Article 12: Record-keeping + +Dify's built-in logging covers several Article 12 requirements: + +| Requirement | Dify Feature | Status | +|------------|-------------|--------| +| Conversation logs | Full conversation history with timestamps | **Covered** | +| Model tracking | Model name recorded per interaction | **Covered** | +| Token usage | Token counts per message | **Covered** | +| Cost tracking | Cost per conversation (if provider reports it) | **Partial** | +| Document retrieval | RAG source documents logged | **Covered** | +| User identification | User session tracking | **Covered** | +| Error logging | Failed generation logs | **Covered** | +| Data retention | Configurable | **Your responsibility** | + +**Retention periods:** The required retention period depends on your role under the Act. Article 18 requires **providers** of high-risk systems to retain logs and technical documentation for **10 years** after market placement. Article 26(6) requires **deployers** to retain logs for at least **6 months**. If you self-host Dify and have substantially modified the system, you may be classified as a provider rather than a deployer. Confirm the applicable retention period with legal counsel. + +## Article 13: Transparency to deployers + +Article 13 requires providers of high-risk AI systems to supply deployers with the information needed to understand and operate the system correctly. This is a **documentation obligation**, not a logging obligation. For Dify deployments, this means the upstream LLM and embedding providers must give you: + +- Instructions for use, including intended purpose and known limitations +- Accuracy metrics and performance benchmarks +- Known or foreseeable risks and residual risks after mitigation +- Technical specifications: input/output formats, training data characteristics, model architecture details + +As a deployer, collect model cards, system documentation, and accuracy reports from each AI provider your Dify application uses. Maintain these as part of your Annex IV technical documentation. + +Dify's platform features provide **supporting evidence** that can inform Article 13 documentation, but they do not satisfy Article 13 on their own: +- **Source attribution** — Dify's RAG citation feature shows which documents informed the response, supporting deployer-side auditing +- **Model identification** — Dify logs which LLM model generates responses, providing evidence for system documentation +- **Conversation logs** — execution history helps compile performance and behavior evidence + +You must independently produce system documentation covering how your specific Dify deployment uses AI, its intended purpose, performance characteristics, and residual risks. + +## Article 50: End-user transparency + +Article 50 requires deployers to inform end users that they are interacting with an AI system. This is a separate obligation from Article 13 and applies even to limited-risk systems. + +For Dify applications serving end users: + +1. **Disclose AI involvement** — tell users they are interacting with an AI system +2. **AI-generated content labeling** — identify AI-generated content as such (e.g., clear labeling in the UI) + +Dify's "citation" feature also supports end-user transparency by showing users which knowledge base documents informed the answer. + +> **Note:** Article 50 applies to chatbots and systems interacting directly with natural persons. It has a separate scope from the high-risk designation under Annex III — it applies even to limited-risk systems. + +## Article 14: Human oversight + +Article 14 requires that high-risk AI systems be designed so that natural persons can effectively oversee them. Dify provides **automated technical safeguards** that support human oversight, but they are not a substitute for it: + +| Dify Feature | What It Does | Oversight Role | +|-------------|-------------|----------------| +| Annotation/feedback system | Human review of AI outputs | **Direct oversight** — humans evaluate and correct AI responses | +| Content moderation | Built-in filtering before responses reach users | **Automated safeguard** — reduces harmful outputs but does not replace human judgment on edge cases | +| Rate limiting | Controls on API usage | **Automated safeguard** — bounds system behavior, supports overseer's ability to maintain control | +| Workflow control | Insert human review steps between AI generation and output | **Oversight enabler** — allows building approval gates into the pipeline | + +These automated controls are necessary building blocks, but Article 14 compliance requires **human oversight procedures** on top of them: +- **Escalation procedures** — define what happens when moderation triggers or edge cases arise (who is notified, what action is taken) +- **Human review pipeline** — for high-stakes decisions, route AI outputs to a qualified person before they take effect +- **Override mechanism** — a human must be able to halt AI responses or override the system's output +- **Competence requirements** — the human overseer must understand the system's capabilities, limitations, and the context of its outputs + +### Recommended pattern + +For high-risk use cases (HR, legal, medical), configure your Dify workflow to require human approval before the AI response is delivered to the end user or acted upon. + +## Knowledge base compliance + +Dify's knowledge base feature has specific compliance implications: + +1. **Data provenance:** Document where your knowledge base documents come from. Article 10 requires data governance for training data; knowledge bases are analogous. +2. **Update tracking:** When you add, remove, or update documents in the knowledge base, log the change. The AI system's behavior changes with its knowledge base. +3. **PII in documents:** If knowledge base documents contain personal data, GDPR applies to the entire RAG pipeline. Implement access controls and consider PII redaction before indexing. +4. **Copyright:** Ensure you have the right to use the documents in your knowledge base for AI-assisted generation. + +## GDPR considerations + +1. **Legal basis** (Article 6): Document why AI processing of user queries is necessary +2. **Data Processing Agreements** (Article 28): Required for each cloud LLM and embedding provider +3. **Data minimization:** Only include necessary context in prompts; avoid sending entire documents when a relevant excerpt suffices +4. **Right to erasure:** If a user requests deletion, ensure their conversations are removed from Dify's logs AND any vector store entries derived from their data +5. **Cross-border transfers:** Providers based outside the EEA — including US-based providers (OpenAI, Anthropic), and any other non-EEA providers you route to — require Standard Contractual Clauses (SCCs) or equivalent safeguards under Chapter V of the GDPR. Review each provider's transfer mechanism individually. + +## Resources + +- [EU AI Act full text](https://artificialintelligenceact.eu/) +- [Dify documentation](https://docs.dify.ai/) +- [Dify SOC 2 compliance](https://dify.ai/trust) + +--- + +*This is not legal advice. Consult a qualified professional for compliance decisions.* diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 7c8a293446..7168d33c24 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -69,7 +69,9 @@ }, "pnpm": { "overrides": { - "rollup@>=4.0.0,<4.59.0": "4.59.0" + "flatted@<=3.4.1": "3.4.2", + "picomatch@>=4.0.0 <4.0.4": "4.0.4", + "rollup@>=4.0.0 <4.59.0": "4.59.0" } } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index c4b299cd73..30d3cf61ee 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -5,7 +5,9 @@ settings: excludeLinksFromLockfile: false overrides: - rollup@>=4.0.0,<4.59.0: 4.59.0 + flatted@<=3.4.1: 3.4.2 + picomatch@>=4.0.0 <4.0.4: 4.0.4 + rollup@>=4.0.0 <4.59.0: 4.59.0 importers: @@ -721,7 +723,7 @@ packages: resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} engines: {node: '>=12.0.0'} peerDependencies: - picomatch: ^3 || ^4 + picomatch: 4.0.4 peerDependenciesMeta: picomatch: optional: true @@ -741,8 +743,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.4.1: - resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} + flatted@3.4.2: + resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==} follow-redirects@1.15.11: resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -949,8 +951,8 @@ packages: picocolors@1.1.1: resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} - picomatch@4.0.3: - resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} engines: {node: '>=12'} pirates@4.0.7: @@ -1815,9 +1817,9 @@ snapshots: fast-levenshtein@2.0.6: {} - fdir@6.5.0(picomatch@4.0.3): + fdir@6.5.0(picomatch@4.0.4): optionalDependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 file-entry-cache@8.0.0: dependencies: @@ -1836,10 +1838,10 @@ snapshots: flat-cache@4.0.1: dependencies: - flatted: 3.4.1 + flatted: 3.4.2 keyv: 4.5.4 - flatted@3.4.1: {} + flatted@3.4.2: {} follow-redirects@1.15.11: {} @@ -2024,7 +2026,7 @@ snapshots: picocolors@1.1.1: {} - picomatch@4.0.3: {} + picomatch@4.0.4: {} pirates@4.0.7: {} @@ -2135,8 +2137,8 @@ snapshots: tinyglobby@0.2.15: dependencies: - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 tinyrainbow@3.0.3: {} @@ -2193,8 +2195,8 @@ snapshots: vite@7.3.1(@types/node@25.4.0): dependencies: esbuild: 0.27.3 - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 postcss: 8.5.8 rollup: 4.59.0 tinyglobby: 0.2.15 @@ -2216,7 +2218,7 @@ snapshots: magic-string: 0.30.21 obug: 2.1.1 pathe: 2.0.3 - picomatch: 4.0.3 + picomatch: 4.0.4 std-env: 3.10.0 tinybench: 2.9.0 tinyexec: 1.0.2 diff --git a/web/.env.example b/web/.env.example index 079c3bdeef..62d4fa6c56 100644 --- a/web/.env.example +++ b/web/.env.example @@ -51,8 +51,6 @@ NEXT_PUBLIC_ALLOW_EMBED= # Allow rendering unsafe URLs which have "data:" scheme. NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false -# Github Access Token, used for invoking Github API -NEXT_PUBLIC_GITHUB_ACCESS_TOKEN= # The maximum number of top-k value for RAG. NEXT_PUBLIC_TOP_K_MAX_VALUE=10 diff --git a/web/.storybook/preview.tsx b/web/.storybook/preview.tsx index 5b38424776..072244c33f 100644 --- a/web/.storybook/preview.tsx +++ b/web/.storybook/preview.tsx @@ -7,7 +7,7 @@ import { I18nClientProvider as I18N } from '../app/components/provider/i18n' import commonEnUS from '../i18n/en-US/common.json' import '../app/styles/globals.css' -import '../app/styles/markdown.scss' +import '../app/styles/markdown.css' import './storybook.css' const queryClient = new QueryClient({ diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index de78ae997e..9e9b3d7168 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -774,7 +774,7 @@ export default translation` const endTime = Date.now() expect(keys.length).toBe(1000) - expect(endTime - startTime).toBeLessThan(1000) // Should complete in under 1 second + expect(endTime - startTime).toBeLessThan(10000) }) it('should handle multiple translation files concurrently', async () => { @@ -796,7 +796,7 @@ export default translation` const endTime = Date.now() expect(keys.length).toBe(20) // 10 files * 2 keys each - expect(endTime - startTime).toBeLessThan(500) + expect(endTime - startTime).toBeLessThan(10000) }) }) diff --git a/web/__tests__/datasets/dataset-settings-flow.test.tsx b/web/__tests__/datasets/dataset-settings-flow.test.tsx index 607cd8c2d5..b4a5e78326 100644 --- a/web/__tests__/datasets/dataset-settings-flow.test.tsx +++ b/web/__tests__/datasets/dataset-settings-flow.test.tsx @@ -19,6 +19,10 @@ import { RETRIEVE_METHOD } from '@/types/app' // --- Mocks --- +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() const mockUpdateDatasetSetting = vi.fn().mockResolvedValue({}) @@ -55,8 +59,11 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), + }, })) // --- Dataset factory --- @@ -311,7 +318,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { describe('Form Submission Validation → All Fields Together', () => { it('should reject empty name on save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -322,10 +329,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() }) diff --git a/web/app/components/__tests__/browser-initializer.spec.ts b/web/__tests__/instrumentation-client.spec.ts similarity index 100% rename from web/app/components/__tests__/browser-initializer.spec.ts rename to web/__tests__/instrumentation-client.spec.ts diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 8edb6705d4..dd5a18b724 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -5,15 +5,21 @@ * upload handling, and task status polling. Verifies the complete plugin * installation pipeline from source discovery to completion. */ -import { beforeEach, describe, expect, it, vi } from 'vitest' -vi.mock('@/config', () => ({ - GITHUB_ACCESS_TOKEN: '', -})) +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { checkForUpdates, fetchReleases, handleUpload } from '@/app/components/plugins/install-plugin/hooks' const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockUploadGitHub = vi.fn() @@ -22,10 +28,6 @@ vi.mock('@/service/plugins', () => ({ checkTaskStatus: vi.fn(), })) -const { useGitHubReleases, useGitHubUpload } = await import( - '@/app/components/plugins/install-plugin/hooks', -) - describe('Plugin Installation Flow Integration', () => { beforeEach(() => { vi.clearAllMocks() @@ -36,22 +38,22 @@ describe('Plugin Installation Flow Integration', () => { it('fetches releases, checks for updates, and uploads the new version', async () => { const mockReleases = [ { - tag_name: 'v2.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v2.difypkg', name: 'plugin-v2.difypkg' }], + tag: 'v2.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v2.difypkg' }], }, { - tag_name: 'v1.5.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.5.difypkg', name: 'plugin-v1.5.difypkg' }], + tag: 'v1.5.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.5.difypkg' }], }, { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) mockUploadGitHub.mockResolvedValue({ @@ -59,8 +61,6 @@ describe('Plugin Installation Flow Integration', () => { unique_identifier: 'test-plugin:2.0.0', }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') expect(releases).toHaveLength(3) expect(releases[0].tag_name).toBe('v2.0.0') @@ -69,7 +69,6 @@ describe('Plugin Installation Flow Integration', () => { expect(needUpdate).toBe(true) expect(toastProps.message).toContain('v2.0.0') - const { handleUpload } = useGitHubUpload() const onSuccess = vi.fn() const result = await handleUpload( 'https://github.com/test-org/test-repo', @@ -96,18 +95,16 @@ describe('Plugin Installation Flow Integration', () => { it('handles no new version available', async () => { const mockReleases = [ { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') const { needUpdate, toastProps } = checkForUpdates(releases, 'v1.0.0') @@ -119,11 +116,9 @@ describe('Plugin Installation Flow Integration', () => { it('handles empty releases', async () => { ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve([]), + json: () => Promise.resolve({ releases: [] }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') expect(releases).toHaveLength(0) @@ -139,7 +134,6 @@ describe('Plugin Installation Flow Integration', () => { status: 404, }) - const { fetchReleases } = useGitHubReleases() const releases = await fetchReleases('nonexistent-org', 'nonexistent-repo') expect(releases).toEqual([]) @@ -151,7 +145,6 @@ describe('Plugin Installation Flow Integration', () => { it('handles upload failure gracefully', async () => { mockUploadGitHub.mockRejectedValue(new Error('Upload failed')) - const { handleUpload } = useGitHubUpload() const onSuccess = vi.fn() await expect( diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index 3950bdf7ee..3a5f2272ed 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -109,7 +109,7 @@ beforeAll(() => { disconnect = vi.fn(() => undefined) unobserve = vi.fn(() => undefined) } - // @ts-expect-error jsdom does not implement IntersectionObserver + // @ts-expect-error test DOM typings do not guarantee IntersectionObserver here globalThis.IntersectionObserver = MockIntersectionObserver }) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx index b3a9bd7abc..1b8d64b911 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx @@ -12,15 +12,15 @@ vi.mock('ahooks', async (importOriginal) => { } }) -vi.mock('react-slider', () => ({ - default: (props: { className?: string, min?: number, max?: number, value: number, onChange: (value: number) => void }) => ( +vi.mock('@/app/components/base/ui/slider', () => ({ + Slider: (props: { className?: string, min?: number, max?: number, value: number, onValueChange: (value: number) => void }) => ( props.onChange(Number(e.target.value))} + onChange={e => props.onValueChange(Number(e.target.value))} /> ), })) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx index ec42e946dd..bce4e74aab 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Unblur } from '@/app/components/base/icons/src/vender/solid/education' -import Slider from '@/app/components/base/slider' +import { Slider } from '@/app/components/base/ui/slider' import { DEFAULT_AGENT_PROMPT, MAX_ITERATIONS_NUM } from '@/config' import ItemPanel from './item-panel' @@ -105,12 +105,13 @@ const AgentSetting: FC = ({ min={maxIterationsMin} max={MAX_ITERATIONS_NUM} value={tempPayload.max_iteration} - onChange={(value) => { + onValueChange={(value) => { setTempPayload({ ...tempPayload, max_iteration: value, }) }} + aria-label={t('agent.setting.maximumIterations.name', { ns: 'appDebug' })} /> { />, ) - const weightedScoreSlider = screen.getAllByRole('slider') - .find(slider => slider.getAttribute('aria-valuemax') === '1') - expect(weightedScoreSlider).toBeDefined() - await user.click(weightedScoreSlider!) + const weightedScoreSlider = screen.getByLabelText('dataset.weightedScore.semantic') + weightedScoreSlider.focus() const callsBefore = onChange.mock.calls.length await user.keyboard('{ArrowRight}') diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx index 9366039414..024432112d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx @@ -3,7 +3,7 @@ import type { DatasetConfigs } from '@/models/debug' import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel, @@ -75,7 +75,7 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-param const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as MockedFunction const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction -let toastNotifySpy: MockInstance +let toastErrorSpy: MockInstance const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { return { @@ -140,7 +140,7 @@ describe('dataset-config/params-config', () => { beforeEach(() => { vi.clearAllMocks() vi.useRealTimers() - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + toastErrorSpy = vi.spyOn(toast, 'error').mockImplementation(() => '') mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ modelList: [], defaultModel: undefined, @@ -154,7 +154,7 @@ describe('dataset-config/params-config', () => { }) afterEach(() => { - toastNotifySpy.mockRestore() + toastErrorSpy.mockRestore() }) // Rendering tests (REQUIRED) @@ -254,10 +254,7 @@ describe('dataset-config/params-config', () => { await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'appDebug.datasetConfig.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('appDebug.datasetConfig.rerankModelRequired') expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 692ae12022..89410203df 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { @@ -66,10 +66,7 @@ const ParamsConfig = ({ } } if (errMsg) { - Toast.notify({ - type: 'error', - message: errMsg, - }) + toast.error(errMsg) } return !errMsg } diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css deleted file mode 100644 index ef9350645a..0000000000 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.css +++ /dev/null @@ -1,7 +0,0 @@ -.weightedScoreSliderTrack { - background: var(--color-util-colors-blue-light-blue-light-500) !important; -} - -.weightedScoreSliderTrack-1 { - background: transparent !important; -} diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx index 7729830348..8e9348c77a 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx @@ -3,6 +3,8 @@ import userEvent from '@testing-library/user-event' import WeightedScore from './weighted-score' describe('WeightedScore', () => { + const getSliderInput = () => screen.getByLabelText('dataset.weightedScore.semantic') + beforeEach(() => { vi.clearAllMocks() }) @@ -48,8 +50,8 @@ describe('WeightedScore', () => { render() // Act - await user.tab() - const slider = screen.getByRole('slider') + const slider = getSliderInput() + slider.focus() expect(slider).toHaveFocus() const callsBefore = onChange.mock.calls.length await user.keyboard('{ArrowRight}') @@ -69,9 +71,8 @@ describe('WeightedScore', () => { render() // Act - await user.tab() - const slider = screen.getByRole('slider') - expect(slider).toHaveFocus() + const slider = getSliderInput() + expect(slider).toBeDisabled() await user.keyboard('{ArrowRight}') // Assert diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index 40beef52e8..d4ce935a4d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -1,9 +1,13 @@ +import type { CSSProperties } from 'react' import { noop } from 'es-toolkit/function' import { memo } from 'react' import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/base/slider' -import { cn } from '@/utils/classnames' -import './weighted-score.css' +import { Slider } from '@/app/components/base/ui/slider' + +const weightedScoreSliderStyle: CSSProperties & Record<'--slider-track' | '--slider-range', string> = { + '--slider-track': 'var(--color-util-colors-teal-teal-500)', + '--slider-range': 'var(--color-util-colors-blue-light-blue-light-500)', +} const formatNumber = (value: number) => { if (value > 0 && value < 1) @@ -33,24 +37,26 @@ const WeightedScore = ({ return (
- !readonly && onChange({ value: [v, (10 - v * 10) / 10] })} - trackClassName="weightedScoreSliderTrack" - disabled={readonly} - /> +
+ !readonly && onChange({ value: [v, (10 - v * 10) / 10] })} + disabled={readonly} + aria-label={t('weightedScore.semantic', { ns: 'dataset' })} + /> +
-
+
{t('weightedScore.semantic', { ns: 'dataset' })}
{formatNumber(value.value[0])}
-
+
{formatNumber(value.value[1])}
{t('weightedScore.keyword', { ns: 'dataset' })} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 188086246a..389ab189e9 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -556,8 +556,8 @@ describe('DebugWithMultipleModel', () => { ) const twoItems = screen.getAllByTestId('debug-item') - expect(twoItems[0].style.width).toBe('calc(50% - 28px)') - expect(twoItems[1].style.width).toBe('calc(50% - 28px)') + expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') }) }) @@ -596,13 +596,13 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(2) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: [], @@ -620,19 +620,19 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(3) expectItemLayout(items[0], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[2], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(200% + 16px)) translateY(0)', classes: [], @@ -655,25 +655,25 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(4) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(0)', classes: ['mr-2', 'mb-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mb-2'], }) expectItemLayout(items[2], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(calc(100% + 8px))', classes: ['mr-2'], }) expectItemLayout(items[3], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))', classes: [], diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index b849b4f015..e933855ca8 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -1,6 +1,3 @@ -/** - * @vitest-environment jsdom - */ import type { ReactNode } from 'react' import type { ModalContextState } from '@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d..b6ca60bd7b 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0d52bd468c..2ef344f816 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -5,11 +5,11 @@ import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
- + { ) const button = screen.getByRole('button', { name: 'Custom Style' }) expect(button).toHaveStyle({ - color: 'rgb(255, 0, 0)', - backgroundColor: 'rgb(0, 0, 255)', + color: 'red', + backgroundColor: 'blue', }) }) diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index e1d8e52eac..00af15e24d 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -5,17 +5,12 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import * as React from 'react' import { useEffect } from 'react' -import { AMPLITUDE_API_KEY, IS_CLOUD_EDITION } from '@/config' +import { AMPLITUDE_API_KEY, isAmplitudeEnabled } from '@/config' export type IAmplitudeProps = { sessionReplaySampleRate?: number } -// Check if Amplitude should be enabled -export const isAmplitudeEnabled = () => { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index b30da72091..5835634eb7 100644 --- a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () => { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/__tests__/index.spec.ts b/web/app/components/base/amplitude/__tests__/index.spec.ts deleted file mode 100644 index 2d7ad6ab84..0000000000 --- a/web/app/components/base/amplitude/__tests__/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from '../index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from '../utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/__tests__/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts index ecbc57e387..f1ff5db1e3 100644 --- a/web/app/components/base/amplitude/__tests__/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('../AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, +vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339e..44cbf728e2 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 0000000000..5dfa0e7b53 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec..8faa8e852e 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent = (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index 2d55ec2720..f53e1f8985 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -1,8 +1,9 @@ import type { ImageLoadingStatus } from '@base-ui/react/avatar' +import type * as React from 'react' import { Avatar as BaseAvatar } from '@base-ui/react/avatar' import { cn } from '@/utils/classnames' -const SIZES = { +const avatarSizeClasses = { 'xxs': { root: 'size-4', text: 'text-[7px]' }, 'xs': { root: 'size-5', text: 'text-[8px]' }, 'sm': { root: 'size-6', text: 'text-[10px]' }, @@ -13,7 +14,7 @@ const SIZES = { '3xl': { root: 'size-16', text: 'text-2xl' }, } as const -export type AvatarSize = keyof typeof SIZES +export type AvatarSize = keyof typeof avatarSizeClasses export type AvatarProps = { name: string @@ -23,7 +24,61 @@ export type AvatarProps = { onLoadingStatusChange?: (status: ImageLoadingStatus) => void } -const BASE_CLASS = 'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600' +export type AvatarRootProps = React.ComponentPropsWithRef & { + size?: AvatarSize +} + +export function AvatarRoot({ + size = 'md', + className, + ...props +}: AvatarRootProps) { + return ( + + ) +} + +export type AvatarImageProps = React.ComponentPropsWithRef + +export function AvatarImage({ + className, + ...props +}: AvatarImageProps) { + return ( + + ) +} + +export type AvatarFallbackProps = React.ComponentPropsWithRef & { + size?: AvatarSize +} + +export function AvatarFallback({ + size = 'md', + className, + ...props +}: AvatarFallbackProps) { + return ( + + ) +} export const Avatar = ({ name, @@ -32,21 +87,18 @@ export const Avatar = ({ className, onLoadingStatusChange, }: AvatarProps) => { - const sizeConfig = SIZES[size] - return ( - + {avatar && ( - )} - + {name?.[0]?.toLocaleUpperCase()} - - + + ) } diff --git a/web/app/components/base/chat/chat/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/__tests__/index.spec.tsx index 781b5e86f3..0100b059f0 100644 --- a/web/app/components/base/chat/chat/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/index.spec.tsx @@ -8,10 +8,10 @@ import Chat from '../index' // ─── Why each mock exists ───────────────────────────────────────────────────── // // Answer – transitively pulls Markdown (rehype/remark/katex), AgentContent, -// WorkflowProcessItem and Operation; none can resolve in jsdom. +// WorkflowProcessItem and Operation; none can resolve in the test DOM runtime. // Question – pulls Markdown, copy-to-clipboard, react-textarea-autosize. // ChatInputArea – pulls js-audio-recorder (requires Web Audio API unavailable in -// jsdom) and VoiceInput / FileContextProvider chains. +// the test DOM runtime) and VoiceInput / FileContextProvider chains. // PromptLogModal– pulls CopyFeedbackNew and deep modal dep chain. // AgentLogModal – pulls @remixicon/react (causes lint push error), useClickAway // from ahooks, and AgentLogDetail (workflow graph renderer). diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index b47fec1d0a..5881f565a4 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -158,7 +158,7 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { rootNodes.push(questionNode) } else { - map[parentMessageId]?.children!.push(questionNode) + map[parentMessageId].children!.push(questionNode) } } } diff --git a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx index d8e00780b1..8839798c15 100644 --- a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import dayjs from '../../utils/dayjs' import Calendar from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx index 910faf9cd4..199ed4ee41 100644 --- a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen, within } from '@testing-library/react' import dayjs, { isDayjsObject } from '../../utils/dayjs' import TimePicker from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx index 332b87cb30..ac0b6d0f57 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx @@ -93,7 +93,6 @@ const ConfigParamModal: FC = ({ className="mt-1" value={(annotationConfig.score_threshold || ANNOTATION_DEFAULT.score_threshold) * 100} onChange={(val) => { - /* v8 ignore next -- callback dispatch depends on react-slider drag mechanics that are flaky in jsdom. @preserve */ setAnnotationConfig({ ...annotationConfig, score_threshold: val / 100, diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx index 2bc30e4ead..ffa9c33043 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/__tests__/index.spec.tsx @@ -1,20 +1,9 @@ import { render, screen } from '@testing-library/react' import ScoreSlider from '../index' -vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider', () => ({ - default: ({ value, onChange, min, max }: { value: number, onChange: (v: number) => void, min: number, max: number }) => ( - onChange(Number(e.target.value))} - /> - ), -})) - describe('ScoreSlider', () => { + const getSliderInput = () => screen.getByLabelText('appDebug.feature.annotation.scoreThreshold.title') + beforeEach(() => { vi.clearAllMocks() }) @@ -22,7 +11,7 @@ describe('ScoreSlider', () => { it('should render the slider', () => { render() - expect(screen.getByTestId('slider')).toBeInTheDocument() + expect(getSliderInput()).toBeInTheDocument() }) it('should display easy match and accurate match labels', () => { @@ -37,14 +26,14 @@ describe('ScoreSlider', () => { it('should render with custom className', () => { const { container } = render() - // Verifying the component renders successfully with a custom className - expect(screen.getByTestId('slider')).toBeInTheDocument() + expect(getSliderInput()).toBeInTheDocument() expect(container.firstChild).toHaveClass('custom-class') }) it('should pass value to the slider', () => { render() - expect(screen.getByTestId('slider')).toHaveValue('95') + expect(getSliderInput()).toHaveValue('95') + expect(screen.getByText('0.95')).toBeInTheDocument() }) }) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx deleted file mode 100644 index 815e8ffe49..0000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/__tests__/index.spec.tsx +++ /dev/null @@ -1,50 +0,0 @@ -import { render, screen } from '@testing-library/react' -import Slider from '../index' - -describe('BaseSlider', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - it('should render the slider component', () => { - render() - - expect(screen.getByRole('slider')).toBeInTheDocument() - }) - - it('should display the formatted value in the thumb', () => { - render() - - expect(screen.getByText('0.85')).toBeInTheDocument() - }) - - it('should use default min/max/step when not provided', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '0') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '50') - }) - - it('should use custom min/max/step when provided', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '80') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '90') - }) - - it('should handle NaN value as 0', () => { - render() - - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '0') - }) - - it('should pass disabled prop', () => { - render() - - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') - }) -}) diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx deleted file mode 100644 index 509426c08e..0000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import ReactSlider from 'react-slider' -import { cn } from '@/utils/classnames' -import s from './style.module.css' - -type ISliderProps = { - className?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ className, max, min, step, value, disabled, onChange }) => { - return ( - ( -
-
-
- {(state.valueNow / 100).toFixed(2)} -
-
-
- )} - /> - ) -} - -export default Slider diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css deleted file mode 100644 index 8ef23b54b5..0000000000 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css +++ /dev/null @@ -1,20 +0,0 @@ -.slider { - position: relative; -} - -.slider.disabled { - opacity: 0.6; -} - -.slider-thumb:focus { - outline: none; -} - -.slider-track { - background-color: #528BFF; - height: 2px; -} - -.slider-track-1 { - background-color: #E5E7EB; -} diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx index c6fb1a0b4e..0363eb2820 100644 --- a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider' +import { Slider } from '@/app/components/base/ui/slider' type Props = { className?: string @@ -10,23 +10,42 @@ type Props = { onChange: (value: number) => void } +const clamp = (value: number, min: number, max: number) => { + if (!Number.isFinite(value)) + return min + + return Math.min(Math.max(value, min), max) +} + const ScoreSlider: FC = ({ className, value, onChange, }) => { const { t } = useTranslation() + const safeValue = clamp(value, 80, 100) return (
-
+
+
+ {(safeValue / 100).toFixed(2)} +
diff --git a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx index 69496903a6..de9cc7ecd0 100644 --- a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx +++ b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx @@ -37,11 +37,11 @@ const FileFromLinkOrLocal = ({ const { handleLoadFileFromLink } = useFile(fileConfig) const disabled = !!fileConfig.number_limits && files.length >= fileConfig.number_limits const fileLinkPlaceholder = t('fileUploader.pasteFileLinkInputPlaceholder', { ns: 'common' }) - /* v8 ignore next -- fallback for missing i18n key is not reliably testable under current global translation mocks in jsdom @preserve */ + /* v8 ignore next -- fallback for a missing i18n key is not reliably testable under the current global translation mocks in the test DOM runtime. @preserve */ const fileLinkPlaceholderText = fileLinkPlaceholder || '' const handleSaveUrl = () => { - /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in jsdom click flow @preserve */ + /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in the current test click flow. @preserve */ if (!url) return diff --git a/web/app/components/base/icons/__tests__/utils.spec.ts b/web/app/components/base/icons/__tests__/utils.spec.ts index a25f39111d..f8534038bf 100644 --- a/web/app/components/base/icons/__tests__/utils.spec.ts +++ b/web/app/components/base/icons/__tests__/utils.spec.ts @@ -62,7 +62,7 @@ describe('generate icon base utils', () => { const { container } = render(generate(node, 'key')) // to svg element expect(container.firstChild).toHaveClass('container') - expect(container.querySelector('span')).toHaveStyle({ color: 'rgb(0, 0, 255)' }) + expect(container.querySelector('span')).toHaveStyle({ color: 'blue' }) }) // add not has children diff --git a/web/app/components/base/input/__tests__/index.spec.tsx b/web/app/components/base/input/__tests__/index.spec.tsx index 2c5b563a12..dfab8617c2 100644 --- a/web/app/components/base/input/__tests__/index.spec.tsx +++ b/web/app/components/base/input/__tests__/index.spec.tsx @@ -99,7 +99,7 @@ describe('Input component', () => { render() const input = screen.getByPlaceholderText(/input/i) expect(input).toHaveClass(customClass) - expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + expect(input).toHaveStyle({ color: 'red' }) }) it('applies large size variant correctly', () => { diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx index 308232fd0f..a16686801c 100644 --- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx @@ -1,6 +1,6 @@ -import { createRequire } from 'node:module' import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import * as echarts from 'echarts' import { Theme } from '@/types/app' import CodeBlock from '../code-block' @@ -10,17 +10,28 @@ type UseThemeReturn = { } const mockUseTheme = vi.fn<() => UseThemeReturn>(() => ({ theme: Theme.light })) -const require = createRequire(import.meta.url) -const echartsCjs = require('echarts') as { - getInstanceByDom: (dom: HTMLDivElement | null) => { - resize: (opts?: { width?: string, height?: string }) => void - } | null -} +const mockEcharts = vi.hoisted(() => { + const state = { + finishedHandler: undefined as undefined | ((event?: unknown) => void), + echartsInstance: { + resize: vi.fn<(opts?: { width?: string, height?: string }) => void>(), + trigger: vi.fn((eventName: string, event?: unknown) => { + if (eventName === 'finished') + state.finishedHandler?.(event) + }), + }, + getInstanceByDom: vi.fn(() => state.echartsInstance), + } + + return state +}) let clientWidthSpy: { mockRestore: () => void } | null = null let clientHeightSpy: { mockRestore: () => void } | null = null let offsetWidthSpy: { mockRestore: () => void } | null = null let offsetHeightSpy: { mockRestore: () => void } | null = null +let consoleErrorSpy: ReturnType | null = null +let consoleWarnSpy: ReturnType | null = null type AudioContextCtor = new () => unknown type WindowWithLegacyAudio = Window & { @@ -59,6 +70,42 @@ vi.mock('@/hooks/use-theme', () => ({ default: () => mockUseTheme(), })) +vi.mock('echarts', () => ({ + getInstanceByDom: mockEcharts.getInstanceByDom, +})) + +vi.mock('echarts-for-react', async () => { + const React = await vi.importActual('react') + + const MockReactEcharts = React.forwardRef(({ + onChartReady, + onEvents, + }: { + onChartReady?: (instance: typeof mockEcharts.echartsInstance) => void + onEvents?: { finished?: (event?: unknown) => void } + }, ref: React.ForwardedRef<{ getEchartsInstance: () => typeof mockEcharts.echartsInstance }>) => { + React.useImperativeHandle(ref, () => ({ + getEchartsInstance: () => mockEcharts.echartsInstance, + })) + + React.useEffect(() => { + mockEcharts.finishedHandler = onEvents?.finished + onChartReady?.(mockEcharts.echartsInstance) + onEvents?.finished?.({}) + return () => { + mockEcharts.finishedHandler = undefined + } + }, [onChartReady, onEvents]) + + return
+ }) + + return { + __esModule: true, + default: MockReactEcharts, + } +}) + vi.mock('@/app/components/base/mermaid', () => ({ __esModule: true, default: ({ PrimitiveCode }: { PrimitiveCode: string }) =>
{PrimitiveCode}
, @@ -74,15 +121,17 @@ const findEchartsHost = async () => { const findEchartsInstance = async () => { const host = await findEchartsHost() await waitFor(() => { - expect(echartsCjs.getInstanceByDom(host)).toBeTruthy() + expect(echarts.getInstanceByDom(host)).toBeTruthy() }) - return echartsCjs.getInstanceByDom(host)! + return echarts.getInstanceByDom(host)! } describe('CodeBlock', () => { beforeEach(() => { vi.clearAllMocks() mockUseTheme.mockReturnValue({ theme: Theme.light }) + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900) clientHeightSpy = vi.spyOn(HTMLElement.prototype, 'clientHeight', 'get').mockReturnValue(400) offsetWidthSpy = vi.spyOn(HTMLElement.prototype, 'offsetWidth', 'get').mockReturnValue(900) @@ -98,6 +147,10 @@ describe('CodeBlock', () => { afterEach(() => { vi.useRealTimers() + consoleErrorSpy?.mockRestore() + consoleWarnSpy?.mockRestore() + consoleErrorSpy = null + consoleWarnSpy = null clientWidthSpy?.mockRestore() clientHeightSpy?.mockRestore() offsetWidthSpy?.mockRestore() diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index b36d8d7788..412c61d52d 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -85,13 +85,30 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any const processedRef = useRef(false) // Track if content was successfully processed const isInitialRenderRef = useRef(true) // Track if this is initial render const chartInstanceRef = useRef(null) // Direct reference to ECharts instance - const resizeTimerRef = useRef(null) // For debounce handling + const resizeTimerRef = useRef | null>(null) // For debounce handling + const chartReadyTimerRef = useRef | null>(null) const finishedEventCountRef = useRef(0) // Track finished event trigger count const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') const isDarkMode = theme === Theme.dark + const clearResizeTimer = useCallback(() => { + if (!resizeTimerRef.current) + return + + clearTimeout(resizeTimerRef.current) + resizeTimerRef.current = null + }, []) + + const clearChartReadyTimer = useCallback(() => { + if (!chartReadyTimerRef.current) + return + + clearTimeout(chartReadyTimerRef.current) + chartReadyTimerRef.current = null + }, []) + const echartsStyle = useMemo(() => ({ height: '350px', width: '100%', @@ -104,26 +121,27 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any // Debounce resize operations const debouncedResize = useCallback(() => { - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() resizeTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() resizeTimerRef.current = null }, 200) - }, []) + }, [clearResizeTimer]) // Handle ECharts instance initialization const handleChartReady = useCallback((instance: any) => { chartInstanceRef.current = instance // Force resize to ensure timeline displays correctly - setTimeout(() => { + clearChartReadyTimer() + chartReadyTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() + chartReadyTimerRef.current = null }, 200) - }, []) + }, [clearChartReadyTimer]) // Store event handlers in useMemo to avoid recreating them const echartsEvents = useMemo(() => ({ @@ -157,10 +175,20 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any return () => { window.removeEventListener('resize', handleResize) - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null } - }, [language, debouncedResize]) + }, [language, debouncedResize, clearResizeTimer, clearChartReadyTimer]) + + useEffect(() => { + return () => { + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null + echartsRef.current = null + } + }, [clearResizeTimer, clearChartReadyTimer]) // Process chart data when content changes useEffect(() => { // Only process echarts content diff --git a/web/app/components/base/node-status/__tests__/index.spec.tsx b/web/app/components/base/node-status/__tests__/index.spec.tsx index f74af4965e..37b12946c8 100644 --- a/web/app/components/base/node-status/__tests__/index.spec.tsx +++ b/web/app/components/base/node-status/__tests__/index.spec.tsx @@ -41,7 +41,7 @@ describe('NodeStatus', () => { it('applies styleCss correctly', () => { const { container } = render() - expect(container.firstChild).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + expect(container.firstChild).toHaveStyle({ color: 'red' }) }) it('applies iconClassName to the icon', () => { diff --git a/web/app/components/base/pagination/__tests__/pagination.spec.tsx b/web/app/components/base/pagination/__tests__/pagination.spec.tsx index 776802ff19..06eac9bfbd 100644 --- a/web/app/components/base/pagination/__tests__/pagination.spec.tsx +++ b/web/app/components/base/pagination/__tests__/pagination.spec.tsx @@ -131,7 +131,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(2) }) @@ -142,7 +142,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() }) @@ -213,7 +213,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(1) }) @@ -225,7 +225,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() }) @@ -318,7 +318,7 @@ describe('Pagination', () => { /> ), }) - fireEvent.keyPress(screen.getByText('4'), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText('4').closest('a')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(3) // 0-indexed }) diff --git a/web/app/components/base/pagination/pagination.tsx b/web/app/components/base/pagination/pagination.tsx index 0eb06b594c..b258090d80 100644 --- a/web/app/components/base/pagination/pagination.tsx +++ b/web/app/components/base/pagination/pagination.tsx @@ -50,7 +50,7 @@ export const PrevButton = ({ tabIndex={disabled ? '-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) previous() @@ -85,7 +85,7 @@ export const NextButton = ({ tabIndex={disabled ? '-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) next() @@ -140,7 +140,7 @@ export const PageButton = ({ }) || undefined } tabIndex={0} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { if (event.key === 'Enter') pagination.setCurrentPage(page - 1) }} diff --git a/web/app/components/base/param-item/__tests__/index-slider.spec.tsx b/web/app/components/base/param-item/__tests__/index-slider.spec.tsx index 0048b89644..6448835844 100644 --- a/web/app/components/base/param-item/__tests__/index-slider.spec.tsx +++ b/web/app/components/base/param-item/__tests__/index-slider.spec.tsx @@ -14,12 +14,14 @@ describe('ParamItem Slider onChange', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('Test Param') + it('should divide slider value by 100 when max < 5', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') // max=1 < 5, so slider value change (50->51) becomes 0.51 @@ -29,9 +31,9 @@ describe('ParamItem Slider onChange', () => { it('should not divide slider value when max >= 5', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') // max=10 >= 5, so value remains raw (5->6) diff --git a/web/app/components/base/param-item/__tests__/index.spec.tsx b/web/app/components/base/param-item/__tests__/index.spec.tsx index 96591446c8..889662c87d 100644 --- a/web/app/components/base/param-item/__tests__/index.spec.tsx +++ b/web/app/components/base/param-item/__tests__/index.spec.tsx @@ -17,6 +17,8 @@ describe('ParamItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('Test Param') + describe('Rendering', () => { it('should render the parameter name', () => { render() @@ -54,7 +56,7 @@ describe('ParamItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -74,7 +76,7 @@ describe('ParamItem', () => { it('should disable Slider when enable is false', () => { render() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) it('should set switch value based on enable prop', () => { @@ -135,7 +137,7 @@ describe('ParamItem', () => { await user.clear(input) expect(defaultProps.onChange).toHaveBeenLastCalledWith('test_param', 0) - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '0') + expect(getSlider()).toHaveAttribute('aria-valuenow', '0') await user.tab() @@ -166,12 +168,12 @@ describe('ParamItem', () => { await user.type(input, '1.5') expect(defaultProps.onChange).toHaveBeenLastCalledWith('test_param', 1) - expect(screen.getByRole('slider')).toHaveAttribute('aria-valuenow', '100') + expect(getSlider()).toHaveAttribute('aria-valuenow', '100') }) it('should pass scaled value to slider when max < 5', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // When max < 5, slider value = value * 100 = 50 expect(slider).toHaveAttribute('aria-valuenow', '50') @@ -179,7 +181,7 @@ describe('ParamItem', () => { it('should pass raw value to slider when max >= 5', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // When max >= 5, slider value = value = 5 expect(slider).toHaveAttribute('aria-valuenow', '5') @@ -212,15 +214,15 @@ describe('ParamItem', () => { render() // Slider should get value * 100 = 50, min * 100 = 0, max * 100 = 100 - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemax', '100') + const slider = getSlider() + expect(slider).toHaveAttribute('max', '100') }) it('should not scale slider value when max >= 5', () => { render() - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemax', '10') + const slider = getSlider() + expect(slider).toHaveAttribute('max', '10') }) it('should expose default minimum of 0 when min is not provided', () => { diff --git a/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx b/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx index 54a13e1b74..ddc286942b 100644 --- a/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx +++ b/web/app/components/base/param-item/__tests__/score-threshold-item.spec.tsx @@ -14,6 +14,8 @@ describe('ScoreThresholdItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('appDebug.datasetConfig.score_threshold') + describe('Rendering', () => { it('should render the translated parameter name', () => { render() @@ -32,7 +34,7 @@ describe('ScoreThresholdItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -63,7 +65,7 @@ describe('ScoreThresholdItem', () => { render() expect(screen.getByRole('textbox')).toBeDisabled() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) }) diff --git a/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx b/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx index 1b8555213b..c84fd50518 100644 --- a/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx +++ b/web/app/components/base/param-item/__tests__/top-k-item.spec.tsx @@ -19,6 +19,8 @@ describe('TopKItem', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('appDebug.datasetConfig.top_k') + describe('Rendering', () => { it('should render the translated parameter name', () => { render() @@ -37,7 +39,7 @@ describe('TopKItem', () => { render() expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) }) @@ -52,7 +54,7 @@ describe('TopKItem', () => { render() expect(screen.getByRole('textbox')).toBeDisabled() - expect(screen.getByRole('slider')).toHaveAttribute('aria-disabled', 'true') + expect(getSlider()).toBeDisabled() }) }) @@ -77,10 +79,10 @@ describe('TopKItem', () => { it('should render slider with max >= 5 so no scaling is applied', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() // max=10 >= 5 so slider shows raw values - expect(slider).toHaveAttribute('aria-valuemax', '10') + expect(slider).toHaveAttribute('max', '10') }) it('should not render a switch (no hasSwitch prop)', () => { @@ -116,9 +118,9 @@ describe('TopKItem', () => { it('should call onChange with integer value when slider changes', async () => { const user = userEvent.setup() render() - const slider = screen.getByRole('slider') + const slider = getSlider() - await user.click(slider) + slider.focus() await user.keyboard('{ArrowRight}') expect(defaultProps.onChange).toHaveBeenLastCalledWith('top_k', 3) diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 63af4bca84..56999fc6ea 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { FC } from 'react' -import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' +import { Slider } from '@/app/components/base/ui/slider' import { NumberField, NumberFieldControls, @@ -78,7 +78,8 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, value={max < 5 ? value * 100 : value} min={min < 1 ? min * 100 : min} max={max < 5 ? max * 100 : max} - onChange={value => onChange(id, value / (max < 5 ? 100 : 1))} + onValueChange={value => onChange(id, value / (max < 5 ? 100 : 1))} + aria-label={name} />
diff --git a/web/app/components/base/premium-badge/__tests__/index.spec.tsx b/web/app/components/base/premium-badge/__tests__/index.spec.tsx index af8ace22f0..d107c07e52 100644 --- a/web/app/components/base/premium-badge/__tests__/index.spec.tsx +++ b/web/app/components/base/premium-badge/__tests__/index.spec.tsx @@ -41,6 +41,6 @@ describe('PremiumBadge', () => { ) const badge = screen.getByText('Premium') expect(badge).toBeInTheDocument() - expect(badge).toHaveStyle('background-color: rgb(255, 0, 0)') // Note: React converts 'red' to 'rgb(255, 0, 0)' + expect(badge).toHaveStyle('background-color: red') }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx index dd2f74f7e5..a16ae9d823 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx @@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' import { BLUR_COMMAND, - COMMAND_PRIORITY_EDITOR, FOCUS_COMMAND, - KEY_ESCAPE_COMMAND, } from 'lexical' import OnBlurBlock from '../on-blur-or-focus-block' import { CaptureEditorPlugin } from '../test-utils' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const renderOnBlurBlock = (props?: { onBlur?: () => void @@ -75,7 +72,7 @@ describe('OnBlurBlock', () => { expect(onFocus).toHaveBeenCalledTimes(1) }) - it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => { + it('should call onBlur when blur target is not var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -85,14 +82,6 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) let handled = false act(() => { @@ -101,18 +90,9 @@ describe('OnBlurBlock', () => { expect(handled).toBe(true) expect(onBlur).toHaveBeenCalledTimes(1) - expect(onEscape).not.toHaveBeenCalled() - - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() }) - it('should dispatch delayed escape when onBlur callback is not provided', async () => { + it('should handle blur when onBlur callback is not provided', async () => { const { getEditor } = renderOnBlurBlock() await waitFor(() => { @@ -121,28 +101,16 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) + let handled = false act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - vi.advanceTimersByTime(200) + handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) }) - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() + expect(handled).toBe(true) }) - it('should skip onBlur and delayed escape when blur target is var-search-input', async () => { + it('should skip onBlur when blur target is var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -152,31 +120,17 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() const target = document.createElement('input') target.classList.add('var-search-input') - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - let handled = false act(() => { handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target)) }) - act(() => { - vi.advanceTimersByTime(200) - }) expect(handled).toBe(true) expect(onBlur).not.toHaveBeenCalled() - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() }) it('should handle focus command when onFocus callback is not provided', async () => { @@ -198,59 +152,6 @@ describe('OnBlurBlock', () => { }) }) - describe('Clear timeout command', () => { - it('should clear scheduled escape timeout when clear command is dispatched', async () => { - const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() }) - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - - act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() - }) - - it('should handle clear command when no timeout is scheduled', async () => { - const { getEditor } = renderOnBlurBlock() - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - - let handled = false - act(() => { - handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - - expect(handled).toBe(true) - }) - }) - describe('Lifecycle cleanup', () => { it('should unregister commands when component unmounts', async () => { const { getEditor, unmount } = renderOnBlurBlock() @@ -266,16 +167,13 @@ describe('OnBlurBlock', () => { let blurHandled = true let focusHandled = true - let clearHandled = true act(() => { blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent()) - clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) expect(blurHandled).toBe(false) expect(focusHandled).toBe(false) - expect(clearHandled).toBe(false) }) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx index 8f6a72a7de..4283910c31 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx @@ -1,14 +1,13 @@ import type { LexicalEditor } from 'lexical' import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' -import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical' +import { $getRoot } from 'lexical' import { CustomTextNode } from '../custom-text/node' import { CaptureEditorPlugin } from '../test-utils' import UpdateBlock, { PROMPT_EDITOR_INSERT_QUICKLY, PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, } from '../update-block' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({ mockUseEventEmitterContextContext: vi.fn(), @@ -157,7 +156,7 @@ describe('UpdateBlock', () => { }) describe('Quick insert event', () => { - it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => { + it('should insert slash when quick insert event matches instance id', async () => { const { emit, getEditor } = setup({ instanceId: 'instance-1' }) await waitFor(() => { @@ -168,13 +167,6 @@ describe('UpdateBlock', () => { selectRootEnd(editor!) - const clearCommandHandler = vi.fn(() => true) - const unregister = editor!.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - clearCommandHandler, - COMMAND_PRIORITY_EDITOR, - ) - emit({ type: PROMPT_EDITOR_INSERT_QUICKLY, instanceId: 'instance-1', @@ -183,9 +175,6 @@ describe('UpdateBlock', () => { await waitFor(() => { expect(readEditorText(editor!)).toBe('/') }) - expect(clearCommandHandler).toHaveBeenCalledTimes(1) - - unregister() }) it('should ignore quick insert event when instance id does not match', async () => { diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 6cc6c3a67f..51b14b76c8 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -23,6 +23,8 @@ import { $createTextNode, $getRoot, $setSelection, + BLUR_COMMAND, + FOCUS_COMMAND, KEY_ESCAPE_COMMAND, } from 'lexical' import * as React from 'react' @@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => { // With a single option group, the only divider should be the workflow-var/options separator. expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1) }) + + describe('blur/focus menu visibility', () => { + it('hides the menu after a 200ms delay when blur command is dispatched', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + vi.useRealTimers() + }) + + it('restores menu visibility when focus command is dispatched after blur hides it', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + vi.useRealTimers() + + await setEditorText(editor, '{', true) + await waitFor(() => { + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + }) + }) + + it('cancels the blur timer when focus arrives before the 200ms timeout', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + const varInput = document.createElement('input') + varInput.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('does not hide the menu when blur target is var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + const target = document.createElement('input') + target.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index 8001a2755b..bebc1b59af 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -21,11 +21,19 @@ import { } from '@floating-ui/react' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin' -import { KEY_ESCAPE_COMMAND } from 'lexical' +import { mergeRegister } from '@lexical/utils' +import { + BLUR_COMMAND, + COMMAND_PRIORITY_EDITOR, + FOCUS_COMMAND, + KEY_ESCAPE_COMMAND, +} from 'lexical' import { Fragment, memo, useCallback, + useEffect, + useRef, useState, } from 'react' import ReactDOM from 'react-dom' @@ -87,6 +95,46 @@ const ComponentPicker = ({ }) const [queryString, setQueryString] = useState(null) + const [blurHidden, setBlurHidden] = useState(false) + const blurTimerRef = useRef | null>(null) + + const clearBlurTimer = useCallback(() => { + if (blurTimerRef.current) { + clearTimeout(blurTimerRef.current) + blurTimerRef.current = null + } + }, []) + + useEffect(() => { + const unregister = mergeRegister( + editor.registerCommand( + BLUR_COMMAND, + (event) => { + clearBlurTimer() + const target = event?.relatedTarget as HTMLElement + if (!target?.classList?.contains('var-search-input')) + blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + editor.registerCommand( + FOCUS_COMMAND, + () => { + clearBlurTimer() + setBlurHidden(false) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + ) + + return () => { + if (blurTimerRef.current) + clearTimeout(blurTimerRef.current) + unregister() + } + }, [editor, clearBlurTimer]) eventEmitter?.useSubscription((v: any) => { if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND) @@ -159,6 +207,8 @@ const ComponentPicker = ({ anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, ) => { + if (blurHidden) + return null if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show))) return null @@ -240,7 +290,7 @@ const ComponentPicker = ({ } ) - }, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) + }, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) return ( void @@ -20,35 +18,13 @@ const OnBlurBlock: FC = ({ }) => { const [editor] = useLexicalComposerContext() - const ref = useRef | null>(null) - useEffect(() => { - const clearHideMenuTimeout = () => { - if (ref.current) { - clearTimeout(ref.current) - ref.current = null - } - } - - const unregister = mergeRegister( - editor.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - () => { - clearHideMenuTimeout() - return true - }, - COMMAND_PRIORITY_EDITOR, - ), + return mergeRegister( editor.registerCommand( BLUR_COMMAND, (event) => { - // Check if the clicked target element is var-search-input const target = event?.relatedTarget as HTMLElement if (!target?.classList?.contains('var-search-input')) { - clearHideMenuTimeout() - ref.current = setTimeout(() => { - editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 200) if (onBlur) onBlur() } @@ -66,11 +42,6 @@ const OnBlurBlock: FC = ({ COMMAND_PRIORITY_EDITOR, ), ) - - return () => { - clearHideMenuTimeout() - unregister() - } }, [editor, onBlur, onFocus]) return null diff --git a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx index abe6ea9a45..7dcda803f2 100644 --- a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx @@ -141,7 +141,7 @@ export default function ShortcutsPopupPlugin({ const portalRef = useRef(null) const lastSelectionRef = useRef(null) - /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/jsdom). @preserve */ + /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/test DOM runtime). @preserve */ const containerEl = useMemo(() => container ?? (typeof document !== 'undefined' ? document.body : null), [container]) const useContainer = !!containerEl && containerEl !== document.body @@ -210,7 +210,7 @@ export default function ShortcutsPopupPlugin({ if (rect.width === 0 && rect.height === 0) { const root = editor.getRootElement() - /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in jsdom is unreliable. @preserve */ + /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in the test DOM runtime is unreliable. @preserve */ if (root) { const sc = range.startContainer const node = sc.nodeType === Node.ELEMENT_NODE diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx index bf89a259af..2d83573b1f 100644 --- a/web/app/components/base/prompt-editor/plugins/update-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical' import { useEventEmitterContextContext } from '@/context/event-emitter' import { textToEditorState } from '../utils' import { CustomTextNode } from './custom-text/node' -import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block' export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER' export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY' @@ -30,8 +29,6 @@ const UpdateBlock = ({ editor.update(() => { const textNode = new CustomTextNode('/') $insertNodes([textNode]) - - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) } }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx index ca4973b830..1591dc44f9 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx @@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical' import { Type } from '@/app/components/workflow/nodes/llm/types' import { BlockEnum } from '@/app/components/workflow/types' import { - CLEAR_HIDE_MENU_TIMEOUT, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND, INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, UPDATE_WORKFLOW_NODES_MAP, @@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => { const insertHandler = mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean const result = insertHandler(['node-1', 'answer']) - expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined) expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith( ['node-1', 'answer'], workflowNodesMap, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 76b2795803..c8cac64d19 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -18,7 +18,6 @@ import { export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND') export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND') -export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT') export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') export type WorkflowVariableBlockProps = { @@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({ editor.registerCommand( INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) diff --git a/web/app/components/base/slider/__tests__/index.spec.tsx b/web/app/components/base/slider/__tests__/index.spec.tsx deleted file mode 100644 index bb1f030689..0000000000 --- a/web/app/components/base/slider/__tests__/index.spec.tsx +++ /dev/null @@ -1,77 +0,0 @@ -import { act, render, screen } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { describe, expect, it, vi } from 'vitest' -import Slider from '../index' - -describe('Slider Component', () => { - it('should render with correct default ARIA limits and current value', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '0') - expect(slider).toHaveAttribute('aria-valuemax', '100') - expect(slider).toHaveAttribute('aria-valuenow', '50') - }) - - it('should apply custom min, max, and step values', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '5') - expect(slider).toHaveAttribute('aria-valuemax', '20') - expect(slider).toHaveAttribute('aria-valuenow', '10') - }) - - it('should default to 0 if the value prop is NaN', () => { - render() - - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuenow', '0') - }) - - it('should call onChange when arrow keys are pressed', async () => { - const user = userEvent.setup() - const onChange = vi.fn() - - render() - - const slider = screen.getByRole('slider') - - await act(async () => { - slider.focus() - await user.keyboard('{ArrowRight}') - }) - - expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange).toHaveBeenCalledWith(21, 0) - }) - - it('should not trigger onChange when disabled', async () => { - const user = userEvent.setup() - const onChange = vi.fn() - render() - - const slider = screen.getByRole('slider') - - expect(slider).toHaveAttribute('aria-disabled', 'true') - - await act(async () => { - slider.focus() - await user.keyboard('{ArrowRight}') - }) - - expect(onChange).not.toHaveBeenCalled() - }) - - it('should apply custom class names', () => { - render( - , - ) - - const sliderWrapper = screen.getByRole('slider').closest('.outer-test') - expect(sliderWrapper).toBeInTheDocument() - - const thumb = screen.getByRole('slider') - expect(thumb).toHaveClass('thumb-test') - }) -}) diff --git a/web/app/components/base/slider/index.stories.tsx b/web/app/components/base/slider/index.stories.tsx deleted file mode 100644 index bde937ffad..0000000000 --- a/web/app/components/base/slider/index.stories.tsx +++ /dev/null @@ -1,635 +0,0 @@ -import type { Meta, StoryObj } from '@storybook/nextjs-vite' -import { useState } from 'react' -import Slider from '.' - -const meta = { - title: 'Base/Data Entry/Slider', - component: Slider, - parameters: { - layout: 'centered', - docs: { - description: { - component: 'Slider component for selecting a numeric value within a range. Built on react-slider with customizable min/max/step values.', - }, - }, - }, - tags: ['autodocs'], - argTypes: { - value: { - control: 'number', - description: 'Current slider value', - }, - min: { - control: 'number', - description: 'Minimum value (default: 0)', - }, - max: { - control: 'number', - description: 'Maximum value (default: 100)', - }, - step: { - control: 'number', - description: 'Step increment (default: 1)', - }, - disabled: { - control: 'boolean', - description: 'Disabled state', - }, - }, - args: { - onChange: (value) => { - console.log('Slider value:', value) - }, - }, -} satisfies Meta - -export default meta -type Story = StoryObj - -// Interactive demo wrapper -const SliderDemo = (args: any) => { - const [value, setValue] = useState(args.value || 50) - - return ( -
- { - setValue(v) - console.log('Slider value:', v) - }} - /> -
- Value: - {' '} - {value} -
-
- ) -} - -// Default state -export const Default: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 1, - disabled: false, - }, -} - -// With custom range -export const CustomRange: Story = { - render: args => , - args: { - value: 25, - min: 0, - max: 50, - step: 1, - disabled: false, - }, -} - -// With step increment -export const WithStepIncrement: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 10, - disabled: false, - }, -} - -// Decimal values -export const DecimalValues: Story = { - render: args => , - args: { - value: 2.5, - min: 0, - max: 5, - step: 0.5, - disabled: false, - }, -} - -// Disabled state -export const Disabled: Story = { - render: args => , - args: { - value: 75, - min: 0, - max: 100, - step: 1, - disabled: true, - }, -} - -// Real-world example - Volume control -const VolumeControlDemo = () => { - const [volume, setVolume] = useState(70) - - const getVolumeIcon = (vol: number) => { - if (vol === 0) - return '🔇' - if (vol < 33) - return '🔈' - if (vol < 66) - return '🔉' - return '🔊' - } - - return ( -
-
-

Volume Control

- {getVolumeIcon(volume)} -
- -
- Mute - - {volume} - % - - Max -
-
- ) -} - -export const VolumeControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Brightness control -const BrightnessControlDemo = () => { - const [brightness, setBrightness] = useState(80) - - return ( -
-
-

Screen Brightness

- ☀️ -
- -
-
- Preview at - {' '} - {brightness} - % brightness -
-
-
- ) -} - -export const BrightnessControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Price range filter -const PriceRangeFilterDemo = () => { - const [maxPrice, setMaxPrice] = useState(500) - const minPrice = 0 - - const products = [ - { name: 'Product A', price: 150 }, - { name: 'Product B', price: 350 }, - { name: 'Product C', price: 600 }, - { name: 'Product D', price: 250 }, - { name: 'Product E', price: 450 }, - ] - - const filteredProducts = products.filter(p => p.price >= minPrice && p.price <= maxPrice) - - return ( -
-

Filter by Price

-
-
- Maximum Price - - $ - {maxPrice} - -
- -
-
-
- Showing - {' '} - {filteredProducts.length} - {' '} - of - {' '} - {products.length} - {' '} - products -
-
- {filteredProducts.map(product => ( -
- {product.name} - - $ - {product.price} - -
- ))} -
-
-
- ) -} - -export const PriceRangeFilter: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Temperature selector -const TemperatureSelectorDemo = () => { - const [temperature, setTemperature] = useState(22) - const fahrenheit = ((temperature * 9) / 5 + 32).toFixed(1) - - return ( -
-

Thermostat Control

-
- -
-
-
-
Celsius
-
- {temperature} - °C -
-
-
-
Fahrenheit
-
- {fahrenheit} - °F -
-
-
-
- {temperature < 18 && '🥶 Too cold'} - {temperature >= 18 && temperature <= 24 && '😊 Comfortable'} - {temperature > 24 && '🥵 Too warm'} -
-
- ) -} - -export const TemperatureSelector: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Progress/completion slider -const ProgressSliderDemo = () => { - const [progress, setProgress] = useState(65) - - return ( -
-

Project Completion

- -
-
- Progress - - {progress} - % - -
-
-
- = 25 ? '✅' : '⏳'}>Planning - 25% -
-
- = 50 ? '✅' : '⏳'}>Development - 50% -
-
- = 75 ? '✅' : '⏳'}>Testing - 75% -
-
- = 100 ? '✅' : '⏳'}>Deployment - 100% -
-
-
-
- ) -} - -export const ProgressSlider: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Zoom control -const ZoomControlDemo = () => { - const [zoom, setZoom] = useState(100) - - return ( -
-

Zoom Level

-
- -
- -
- -
-
- 50% - - {zoom} - % - - 200% -
-
-
Preview content
-
-
- ) -} - -export const ZoomControl: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - AI model parameters -const AIModelParametersDemo = () => { - const [temperature, setTemperature] = useState(0.7) - const [maxTokens, setMaxTokens] = useState(2000) - const [topP, setTopP] = useState(0.9) - - return ( -
-

Model Configuration

-
-
-
- - {temperature} -
- -

- Controls randomness. Lower is more focused, higher is more creative. -

-
- -
-
- - {maxTokens} -
- -

- Maximum length of generated response. -

-
- -
-
- - {topP} -
- -

- Nucleus sampling threshold. -

-
-
-
-
- Temperature: - {' '} - {temperature} -
-
- Max Tokens: - {' '} - {maxTokens} -
-
- Top P: - {' '} - {topP} -
-
-
- ) -} - -export const AIModelParameters: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Real-world example - Image quality selector -const ImageQualitySelectorDemo = () => { - const [quality, setQuality] = useState(80) - - const getQualityLabel = (q: number) => { - if (q < 50) - return 'Low' - if (q < 70) - return 'Medium' - if (q < 90) - return 'High' - return 'Maximum' - } - - const estimatedSize = Math.round((quality / 100) * 5) - - return ( -
-

Image Export Quality

- -
-
-
Quality
-
{getQualityLabel(quality)}
-
- {quality} - % -
-
-
-
File Size
-
- ~ - {estimatedSize} - {' '} - MB -
-
Estimated
-
-
-
- ) -} - -export const ImageQualitySelector: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Multiple sliders -const MultipleSlidersDemo = () => { - const [red, setRed] = useState(128) - const [green, setGreen] = useState(128) - const [blue, setBlue] = useState(128) - - const rgbColor = `rgb(${red}, ${green}, ${blue})` - - return ( -
-

RGB Color Picker

-
-
-
- - {red} -
- -
-
-
- - {green} -
- -
-
-
- - {blue} -
- -
-
-
-
-
-
Color Value
-
{rgbColor}
-
- # - {red.toString(16).padStart(2, '0')} - {green.toString(16).padStart(2, '0')} - {blue.toString(16).padStart(2, '0')} -
-
-
-
- ) -} - -export const MultipleSliders: Story = { - render: () => , - parameters: { controls: { disable: true } }, -} as unknown as Story - -// Interactive playground -export const Playground: Story = { - render: args => , - args: { - value: 50, - min: 0, - max: 100, - step: 1, - disabled: false, - }, -} diff --git a/web/app/components/base/slider/index.tsx b/web/app/components/base/slider/index.tsx deleted file mode 100644 index 4e4656f590..0000000000 --- a/web/app/components/base/slider/index.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import ReactSlider from 'react-slider' -import { cn } from '@/utils/classnames' -import './style.css' - -type ISliderProps = { - className?: string - thumbClassName?: string - trackClassName?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ - className, - thumbClassName, - trackClassName, - max, - min, - step, - value, - disabled, - onChange, -}) => { - return ( - - ) -} - -export default Slider diff --git a/web/app/components/base/slider/style.css b/web/app/components/base/slider/style.css deleted file mode 100644 index 5d87fb0897..0000000000 --- a/web/app/components/base/slider/style.css +++ /dev/null @@ -1,11 +0,0 @@ -.slider.disabled { - opacity: 0.6; -} - -.slider-track { - background-color: var(--color-components-slider-range); -} - -.slider-track-1 { - background-color: var(--color-components-slider-track); -} diff --git a/web/app/components/base/ui/slider/__tests__/index.spec.tsx b/web/app/components/base/ui/slider/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f34de5010d --- /dev/null +++ b/web/app/components/base/ui/slider/__tests__/index.spec.tsx @@ -0,0 +1,73 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { describe, expect, it, vi } from 'vitest' +import { Slider } from '../index' + +describe('Slider', () => { + const getSliderInput = () => screen.getByLabelText('Value') + + it('should render with correct default ARIA limits and current value', () => { + render() + + const slider = getSliderInput() + expect(slider).toHaveAttribute('min', '0') + expect(slider).toHaveAttribute('max', '100') + expect(slider).toHaveAttribute('aria-valuenow', '50') + }) + + it('should apply custom min, max, and step values', () => { + render() + + const slider = getSliderInput() + expect(slider).toHaveAttribute('min', '5') + expect(slider).toHaveAttribute('max', '20') + expect(slider).toHaveAttribute('aria-valuenow', '10') + }) + + it('should clamp non-finite values to min', () => { + render() + + expect(getSliderInput()).toHaveAttribute('aria-valuenow', '5') + }) + + it('should call onValueChange when arrow keys are pressed', async () => { + const user = userEvent.setup() + const onValueChange = vi.fn() + + render() + + const slider = getSliderInput() + + await act(async () => { + slider.focus() + await user.keyboard('{ArrowRight}') + }) + + expect(onValueChange).toHaveBeenCalledTimes(1) + expect(onValueChange).toHaveBeenLastCalledWith(21, expect.anything()) + }) + + it('should not trigger onValueChange when disabled', async () => { + const user = userEvent.setup() + const onValueChange = vi.fn() + render() + + const slider = getSliderInput() + + expect(slider).toBeDisabled() + + await act(async () => { + slider.focus() + await user.keyboard('{ArrowRight}') + }) + + expect(onValueChange).not.toHaveBeenCalled() + }) + + it('should apply custom class names on root', () => { + const { container } = render() + + const sliderWrapper = container.querySelector('.outer-test') + expect(sliderWrapper).toBeInTheDocument() + }) +}) diff --git a/web/app/components/base/ui/slider/index.stories.tsx b/web/app/components/base/ui/slider/index.stories.tsx new file mode 100644 index 0000000000..b61a6cb288 --- /dev/null +++ b/web/app/components/base/ui/slider/index.stories.tsx @@ -0,0 +1,92 @@ +import type { Meta, StoryObj } from '@storybook/nextjs-vite' +import type * as React from 'react' +import { useState } from 'react' +import { Slider } from '.' + +const meta = { + title: 'Base UI/Data Entry/Slider', + component: Slider, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Single-value horizontal slider built on Base UI.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + value: { + control: 'number', + }, + min: { + control: 'number', + }, + max: { + control: 'number', + }, + step: { + control: 'number', + }, + disabled: { + control: 'boolean', + }, + }, +} satisfies Meta + +export default meta + +type Story = StoryObj + +function SliderDemo({ + value: initialValue = 50, + defaultValue: _defaultValue, + ...args +}: React.ComponentProps) { + const [value, setValue] = useState(initialValue) + + return ( +
+ +
+ {value} +
+
+ ) +} + +export const Default: Story = { + render: args => , + args: { + value: 50, + min: 0, + max: 100, + step: 1, + }, +} + +export const Decimal: Story = { + render: args => , + args: { + value: 0.5, + min: 0, + max: 1, + step: 0.1, + }, +} + +export const Disabled: Story = { + render: args => , + args: { + value: 75, + min: 0, + max: 100, + step: 1, + disabled: true, + }, +} diff --git a/web/app/components/base/ui/slider/index.tsx b/web/app/components/base/ui/slider/index.tsx new file mode 100644 index 0000000000..8e1dc969bc --- /dev/null +++ b/web/app/components/base/ui/slider/index.tsx @@ -0,0 +1,100 @@ +'use client' + +import { Slider as BaseSlider } from '@base-ui/react/slider' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type SliderRootProps = BaseSlider.Root.Props +type SliderThumbProps = BaseSlider.Thumb.Props + +type SliderBaseProps = Pick< + SliderRootProps, + 'onValueChange' | 'min' | 'max' | 'step' | 'disabled' | 'name' +> & Pick & { + className?: string +} + +type ControlledSliderProps = SliderBaseProps & { + value: number + defaultValue?: never +} + +type UncontrolledSliderProps = SliderBaseProps & { + value?: never + defaultValue?: number +} + +export type SliderProps = ControlledSliderProps | UncontrolledSliderProps + +const sliderRootClassName = 'group/slider relative inline-flex w-full data-[disabled]:opacity-30' +const sliderControlClassName = cn( + 'relative flex h-5 w-full touch-none select-none items-center', + 'data-[disabled]:cursor-not-allowed', +) +const sliderTrackClassName = cn( + 'relative h-1 w-full overflow-hidden rounded-full', + 'bg-[var(--slider-track,var(--color-components-slider-track))]', +) +const sliderIndicatorClassName = cn( + 'h-full rounded-full', + 'bg-[var(--slider-range,var(--color-components-slider-range))]', +) +const sliderThumbClassName = cn( + 'block h-5 w-2 shrink-0 rounded-[3px] border-[0.5px]', + 'border-[var(--slider-knob-border,var(--color-components-slider-knob-border))]', + 'bg-[var(--slider-knob,var(--color-components-slider-knob))] shadow-sm', + 'transition-[background-color,border-color,box-shadow,opacity] motion-reduce:transition-none', + 'hover:bg-[var(--slider-knob-hover,var(--color-components-slider-knob-hover))]', + 'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-components-slider-knob-border-hover focus-visible:ring-offset-0', + 'active:shadow-md', + 'group-data-[disabled]/slider:bg-[var(--slider-knob-disabled,var(--color-components-slider-knob-disabled))]', + 'group-data-[disabled]/slider:border-[var(--slider-knob-border,var(--color-components-slider-knob-border))]', + 'group-data-[disabled]/slider:shadow-none', +) + +const getSafeValue = (value: number | undefined, min: number) => { + if (value === undefined) + return undefined + + return Number.isFinite(value) ? value : min +} + +export function Slider({ + value, + defaultValue, + onValueChange, + min = 0, + max = 100, + step = 1, + disabled = false, + name, + className, + 'aria-label': ariaLabel, + 'aria-labelledby': ariaLabelledby, +}: SliderProps) { + return ( + + + + + + + + + ) +} diff --git a/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx b/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx new file mode 100644 index 0000000000..1441653c9c --- /dev/null +++ b/web/app/components/billing/partner-stack/__tests__/cookie-recorder.spec.tsx @@ -0,0 +1,45 @@ +import { render } from '@testing-library/react' +import PartnerStackCookieRecorder from '../cookie-recorder' + +let isCloudEdition = true + +const saveOrUpdate = vi.fn() + +vi.mock('@/config', () => ({ + get IS_CLOUD_EDITION() { + return isCloudEdition + }, +})) + +vi.mock('../use-ps-info', () => ({ + default: () => ({ + saveOrUpdate, + }), +})) + +describe('PartnerStackCookieRecorder', () => { + beforeEach(() => { + vi.clearAllMocks() + isCloudEdition = true + }) + + it('should call saveOrUpdate once on mount when running in cloud edition', () => { + render() + + expect(saveOrUpdate).toHaveBeenCalledTimes(1) + }) + + it('should not call saveOrUpdate when not running in cloud edition', () => { + isCloudEdition = false + + render() + + expect(saveOrUpdate).not.toHaveBeenCalled() + }) + + it('should render null', () => { + const { container } = render() + + expect(container.innerHTML).toBe('') + }) +}) diff --git a/web/app/components/billing/partner-stack/cookie-recorder.tsx b/web/app/components/billing/partner-stack/cookie-recorder.tsx new file mode 100644 index 0000000000..3c75b2973c --- /dev/null +++ b/web/app/components/billing/partner-stack/cookie-recorder.tsx @@ -0,0 +1,19 @@ +'use client' + +import { useEffect } from 'react' +import { IS_CLOUD_EDITION } from '@/config' +import usePSInfo from './use-ps-info' + +const PartnerStackCookieRecorder = () => { + const { saveOrUpdate } = usePSInfo() + + useEffect(() => { + if (!IS_CLOUD_EDITION) + return + saveOrUpdate() + }, []) + + return null +} + +export default PartnerStackCookieRecorder diff --git a/web/app/components/billing/partner-stack/use-ps-info.ts b/web/app/components/billing/partner-stack/use-ps-info.ts index 7c45d7ef87..5a83dec0e5 100644 --- a/web/app/components/billing/partner-stack/use-ps-info.ts +++ b/web/app/components/billing/partner-stack/use-ps-info.ts @@ -24,7 +24,7 @@ const usePSInfo = () => { }] = useBoolean(false) const { mutateAsync } = useBindPartnerStackInfo() // Save to top domain. cloud.dify.ai => .dify.ai - const domain = globalThis.location.hostname.replace('cloud', '') + const domain = globalThis.location?.hostname.replace('cloud', '') const saveOrUpdate = useCallback(() => { if (!psPartnerKey || !psClickId) @@ -39,7 +39,7 @@ const usePSInfo = () => { path: '/', domain, }) - }, [psPartnerKey, psClickId, isPSChanged]) + }, [psPartnerKey, psClickId, isPSChanged, domain]) const bind = useCallback(async () => { if (psPartnerKey && psClickId && !hasBind) { @@ -59,7 +59,7 @@ const usePSInfo = () => { Cookies.remove(PARTNER_STACK_CONFIG.cookieName, { path: '/', domain }) setBind() } - }, [psPartnerKey, psClickId, mutateAsync, hasBind, setBind]) + }, [psPartnerKey, psClickId, hasBind, domain, setBind, mutateAsync]) return { psPartnerKey, psClickId, diff --git a/web/app/components/billing/pricing/__tests__/header.spec.tsx b/web/app/components/billing/pricing/__tests__/header.spec.tsx index 0aadc3b0ce..cb8991ff42 100644 --- a/web/app/components/billing/pricing/__tests__/header.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/header.spec.tsx @@ -1,12 +1,14 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' -import { Dialog } from '@/app/components/base/ui/dialog' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import Header from '../header' function renderHeader(onClose: () => void) { return render( -
+ +
+
, ) } @@ -24,7 +26,7 @@ describe('Header', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) @@ -33,7 +35,7 @@ describe('Header', () => { const handleClose = vi.fn() renderHeader(handleClose) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) expect(handleClose).toHaveBeenCalledTimes(1) }) @@ -41,11 +43,11 @@ describe('Header', () => { describe('Edge Cases', () => { it('should render structural elements with translation keys', () => { - const { container } = renderHeader(vi.fn()) + renderHeader(vi.fn()) - expect(container.querySelector('span')).toBeInTheDocument() - expect(container.querySelector('p')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) }) diff --git a/web/app/components/billing/pricing/__tests__/index.spec.tsx b/web/app/components/billing/pricing/__tests__/index.spec.tsx index 36848cd463..a8d0a4329e 100644 --- a/web/app/components/billing/pricing/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/index.spec.tsx @@ -68,6 +68,7 @@ describe('Pricing', () => { it('should render pricing header and localized footer link', () => { render() + expect(screen.getByRole('dialog', { name: 'billing.plansCommon.title.plans' })).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 'https://dify.ai/en/pricing#plans-and-features') }) diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 0d3fd965b0..1422ec1cb1 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -28,8 +28,9 @@ const Footer = ({ {t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/header.tsx b/web/app/components/billing/pricing/header.tsx index d0ffe100db..5ab1895667 100644 --- a/web/app/components/billing/pricing/header.tsx +++ b/web/app/components/billing/pricing/header.tsx @@ -1,5 +1,6 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' +import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog' import { cn } from '@/utils/classnames' import Button from '../../base/button' import DifyLogo from '../../base/logo/dify-logo' @@ -18,24 +19,25 @@ const Header = ({
-
+ - {t('plansCommon.title.plans', { ns: 'billing' })} - +
-

+ {t('plansCommon.title.description', { ns: 'billing' })} -

+ ))} diff --git a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts index 63ac30630e..f29d85b460 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts @@ -5,9 +5,15 @@ import { IndexingType } from '@/app/components/datasets/create/step-two' import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' import { useDatasetCardState } from '../use-dataset-card-state' -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, }, })) @@ -299,7 +305,7 @@ describe('useDatasetCardState', () => { describe('Error Handling', () => { it('should show error toast when export pipeline fails', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockExportPipeline.mockRejectedValue(new Error('Export failed')) const dataset = createMockDataset({ pipeline_id: 'pipeline-1' }) @@ -311,14 +317,11 @@ describe('useDatasetCardState', () => { await result.current.handleExportPipeline() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should handle Response error in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const mockResponse = new Response(JSON.stringify({ message: 'API Error' }), { status: 400, }) @@ -333,14 +336,11 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.stringContaining('API Error'), - }) + expect(toast.error).toHaveBeenCalledWith(expect.stringContaining('API Error')) }) it('should handle generic Error in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockCheckUsage.mockRejectedValue(new Error('Network error')) const dataset = createMockDataset() @@ -352,14 +352,11 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Network error', - }) + expect(toast.error).toHaveBeenCalledWith('Network error') }) it('should handle error without message in detectIsUsedByApp', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') mockCheckUsage.mockRejectedValue({}) const dataset = createMockDataset() @@ -371,10 +368,7 @@ describe('useDatasetCardState', () => { await result.current.detectIsUsedByApp() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Unknown error', - }) + expect(toast.error).toHaveBeenCalledWith('dataset.unknownError') }) it('should handle exporting state correctly', async () => { diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts index 4bd8357f1c..850eee4364 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -2,7 +2,7 @@ import type { Tag } from '@/app/components/base/tag-management/constant' import type { DataSet } from '@/models/datasets' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' import { useExportPipelineDSL } from '@/service/use-pipeline' import { downloadBlob } from '@/utils/download' @@ -70,7 +70,7 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { - Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast.error(t('exportFailed', { ns: 'app' })) } finally { setExporting(false) @@ -93,10 +93,10 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO catch (e: unknown) { if (e instanceof Response) { const res = await e.json() - Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) + toast.error(res?.message || t('unknownError', { ns: 'dataset' })) } else { - Toast.notify({ type: 'error', message: (e as Error)?.message || 'Unknown error' }) + toast.error((e as Error)?.message || t('unknownError', { ns: 'dataset' })) } } }, [dataset.id, checkUsage, t]) @@ -104,7 +104,7 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO const onConfirmDelete = useCallback(async () => { try { await deleteDatasetMutation(dataset.id) - Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) + toast.success(t('datasetDeleted', { ns: 'dataset' })) onSuccess?.() } finally { diff --git a/web/app/components/datasets/settings/form/__tests__/index.spec.tsx b/web/app/components/datasets/settings/form/__tests__/index.spec.tsx index 9eeb62b8e8..a3a22e000b 100644 --- a/web/app/components/datasets/settings/form/__tests__/index.spec.tsx +++ b/web/app/components/datasets/settings/form/__tests__/index.spec.tsx @@ -6,6 +6,10 @@ import { RETRIEVE_METHOD } from '@/types/app' import { IndexingType } from '../../../create/step-two' import Form from '../index' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + // Mock contexts const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() @@ -189,9 +193,10 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), }, })) @@ -391,7 +396,7 @@ describe('Form', () => { }) it('should show error when trying to save with empty name', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') render(
) // Clear the name @@ -403,10 +408,7 @@ describe('Form', () => { fireEvent.click(saveButton) await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) }) diff --git a/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts b/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts index f27b542b1e..00462619aa 100644 --- a/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts +++ b/web/app/components/datasets/settings/form/hooks/__tests__/use-form-state.spec.ts @@ -6,6 +6,11 @@ import { RETRIEVE_METHOD } from '@/types/app' import { IndexingType } from '../../../../create/step-two' import { useFormState } from '../use-form-state' +const { mockToastSuccess, mockToastError } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) + // Mock contexts const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() @@ -122,9 +127,10 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: mockToastSuccess, + error: mockToastError, }, })) @@ -423,7 +429,7 @@ describe('useFormState', () => { describe('handleSave', () => { it('should show error toast when name is empty', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -434,14 +440,11 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should show error toast when name is whitespace only', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -452,10 +455,7 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should call updateDatasetSetting with correct params', async () => { @@ -477,7 +477,7 @@ describe('useFormState', () => { }) it('should show success toast on successful save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) await act(async () => { @@ -485,10 +485,7 @@ describe('useFormState', () => { }) await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'success', - message: expect.any(String), - }) + expect(toast.success).toHaveBeenCalledWith(expect.any(String)) }) }) @@ -553,7 +550,7 @@ describe('useFormState', () => { it('should show error toast on save failure', async () => { const { updateDatasetSetting } = await import('@/service/datasets') - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') vi.mocked(updateDatasetSetting).mockRejectedValueOnce(new Error('Network error')) const { result } = renderHook(() => useFormState()) @@ -562,10 +559,7 @@ describe('useFormState', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) }) it('should include partial_member_list when permission is partialMembers', async () => { diff --git a/web/app/components/datasets/settings/form/hooks/use-form-state.ts b/web/app/components/datasets/settings/form/hooks/use-form-state.ts index 614995d43a..d00534f7f4 100644 --- a/web/app/components/datasets/settings/form/hooks/use-form-state.ts +++ b/web/app/components/datasets/settings/form/hooks/use-form-state.ts @@ -6,7 +6,7 @@ import type { IconInfo, SummaryIndexSetting as SummaryIndexSettingType } from '@ import type { RetrievalConfig } from '@/types/app' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -123,12 +123,12 @@ export const useFormState = () => { return if (!name?.trim()) { - Toast.notify({ type: 'error', message: t('form.nameError', { ns: 'datasetSettings' }) }) + toast.error(t('form.nameError', { ns: 'datasetSettings' })) return } if (!isReRankModelSelected({ rerankModelList, retrievalConfig, indexMethod })) { - Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) }) + toast.error(t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })) return } @@ -176,7 +176,7 @@ export const useFormState = () => { } await updateDatasetSetting({ datasetId: currentDataset!.id, body }) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) if (mutateDatasets) { await mutateDatasets() @@ -184,7 +184,7 @@ export const useFormState = () => { } } catch { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) } finally { setLoading(false) diff --git a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx index ae2c17d880..7441274155 100644 --- a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx +++ b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx @@ -14,6 +14,8 @@ describe('IndexMethod', () => { vi.clearAllMocks() }) + const getKeywordSlider = () => screen.getByLabelText('datasetSettings.form.numberOfKeywords') + describe('Rendering', () => { it('should render without crashing', () => { render() @@ -123,8 +125,7 @@ describe('IndexMethod', () => { describe('KeywordNumber', () => { it('should render KeywordNumber component inside Economy option', () => { render() - // KeywordNumber has a slider - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getKeywordSlider()).toBeInTheDocument() }) it('should pass keywordNumber to KeywordNumber component', () => { diff --git a/web/app/components/datasets/settings/index-method/__tests__/keyword-number.spec.tsx b/web/app/components/datasets/settings/index-method/__tests__/keyword-number.spec.tsx index e7ba8af6f1..cd0d56bbeb 100644 --- a/web/app/components/datasets/settings/index-method/__tests__/keyword-number.spec.tsx +++ b/web/app/components/datasets/settings/index-method/__tests__/keyword-number.spec.tsx @@ -11,6 +11,8 @@ describe('KeyWordNumber', () => { vi.clearAllMocks() }) + const getSlider = () => screen.getByLabelText('datasetSettings.form.numberOfKeywords') + describe('Rendering', () => { it('should render without crashing', () => { render() @@ -31,8 +33,7 @@ describe('KeyWordNumber', () => { it('should render slider', () => { render() - // Slider has a slider role - expect(screen.getByRole('slider')).toBeInTheDocument() + expect(getSlider()).toBeInTheDocument() }) it('should render input number field', () => { @@ -61,7 +62,7 @@ describe('KeyWordNumber', () => { it('should pass correct value to slider', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() expect(slider).toHaveAttribute('aria-valuenow', '30') }) }) @@ -71,8 +72,7 @@ describe('KeyWordNumber', () => { const handleChange = vi.fn() render() - const slider = screen.getByRole('slider') - // Verify slider is rendered and interactive + const slider = getSlider() expect(slider).toBeInTheDocument() expect(slider).not.toBeDisabled() }) @@ -109,14 +109,14 @@ describe('KeyWordNumber', () => { describe('Slider Configuration', () => { it('should have max value of 50', () => { render() - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemax', '50') + const slider = getSlider() + expect(slider).toHaveAttribute('max', '50') }) it('should have min value of 0', () => { render() - const slider = screen.getByRole('slider') - expect(slider).toHaveAttribute('aria-valuemin', '0') + const slider = getSlider() + expect(slider).toHaveAttribute('min', '0') }) }) @@ -162,7 +162,7 @@ describe('KeyWordNumber', () => { describe('Accessibility', () => { it('should have accessible slider', () => { render() - const slider = screen.getByRole('slider') + const slider = getSlider() expect(slider).toBeInTheDocument() }) diff --git a/web/app/components/datasets/settings/index-method/keyword-number.tsx b/web/app/components/datasets/settings/index-method/keyword-number.tsx index 95810d7d49..feb63c1d65 100644 --- a/web/app/components/datasets/settings/index-method/keyword-number.tsx +++ b/web/app/components/datasets/settings/index-method/keyword-number.tsx @@ -1,7 +1,6 @@ import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/base/slider' import Tooltip from '@/app/components/base/tooltip' import { NumberField, @@ -11,6 +10,7 @@ import { NumberFieldIncrement, NumberFieldInput, } from '@/app/components/base/ui/number-field' +import { Slider } from '@/app/components/base/ui/slider' const MIN_KEYWORD_NUMBER = 0 const MAX_KEYWORD_NUMBER = 50 @@ -47,7 +47,8 @@ const KeyWordNumber = ({ value={keywordNumber} min={MIN_KEYWORD_NUMBER} max={MAX_KEYWORD_NUMBER} - onChange={onKeywordNumberChange} + onValueChange={onKeywordNumberChange} + aria-label={t('form.numberOfKeywords', { ns: 'datasetSettings' })} /> { vi.clearAllMocks() vi.spyOn(console, 'error').mockImplementation(() => {}) vi.useFakeTimers({ shouldAdvanceTime: true }) - // jsdom does not implement scrollBy; mock it to prevent stderr noise + // The test DOM runtime does not implement scrollBy; mock it to prevent stderr noise window.scrollBy = vi.fn() }) diff --git a/web/app/components/develop/__tests__/use-doc-toc.spec.ts b/web/app/components/develop/__tests__/use-doc-toc.spec.ts index e437e13065..b20c2c8ecf 100644 --- a/web/app/components/develop/__tests__/use-doc-toc.spec.ts +++ b/web/app/components/develop/__tests__/use-doc-toc.spec.ts @@ -307,7 +307,7 @@ describe('useDocToc', () => { it('should update activeSection when scrolling past a section', async () => { vi.useFakeTimers() - // innerHeight/2 = 384 in jsdom (default 768), so top <= 384 means "scrolled past" + // innerHeight/2 = 384 with the default test viewport height (768), so top <= 384 means "scrolled past" const { scrollContainer, cleanup } = setupScrollDOM([ { id: 'intro', text: 'Intro', top: 100 }, { id: 'details', text: 'Details', top: 600 }, diff --git a/web/app/components/devtools/agentation-loader.tsx b/web/app/components/devtools/agentation-loader.tsx new file mode 100644 index 0000000000..87e1b44c87 --- /dev/null +++ b/web/app/components/devtools/agentation-loader.tsx @@ -0,0 +1,13 @@ +'use client' + +import { IS_DEV } from '@/config' +import dynamic from '@/next/dynamic' + +const Agentation = dynamic(() => import('agentation').then(module => module.Agentation), { ssr: false }) + +export function AgentationLoader() { + if (!IS_DEV) + return null + + return +} diff --git a/web/app/components/explore/banner/__tests__/banner-item.spec.tsx b/web/app/components/explore/banner/__tests__/banner-item.spec.tsx index de35814e8e..2d07cbddd8 100644 --- a/web/app/components/explore/banner/__tests__/banner-item.spec.tsx +++ b/web/app/components/explore/banner/__tests__/banner-item.spec.tsx @@ -1,3 +1,4 @@ +import type { ComponentProps } from 'react' import type { Banner } from '@/models/app' import { cleanup, fireEvent, render, screen } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' @@ -5,6 +6,11 @@ import { BannerItem } from '../banner-item' const mockScrollTo = vi.fn() const mockSlideNodes = vi.fn() +const mockTrackEvent = vi.fn() + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) vi.mock('@/app/components/base/carousel', () => ({ useCarousel: () => ({ @@ -48,19 +54,34 @@ class MockResizeObserver { } } +const renderBannerItem = ( + banner: Banner = createMockBanner(), + props: Partial> = {}, +) => { + return render( + , + ) +} + describe('BannerItem', () => { let mockWindowOpen: ReturnType beforeEach(() => { mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) - mockSlideNodes.mockReturnValue([{}, {}, {}]) // 3 slides + mockSlideNodes.mockReturnValue([{}, {}, {}]) vi.stubGlobal('ResizeObserver', MockResizeObserver) Object.defineProperty(window, 'innerWidth', { writable: true, configurable: true, - value: 1400, // Above RESPONSIVE_BREAKPOINT (1200) + value: 1400, }) }) @@ -73,81 +94,51 @@ describe('BannerItem', () => { describe('basic rendering', () => { it('renders banner category', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Featured')).toBeInTheDocument() }) it('renders banner title', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) it('renders banner description', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test banner description text')).toBeInTheDocument() }) it('renders banner image with correct src and alt', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() const image = screen.getByRole('img') expect(image).toHaveAttribute('src', 'https://example.com/image.png') expect(image).toHaveAttribute('alt', 'Test Banner Title') }) it('renders view more text', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('explore.banner.viewMore')).toBeInTheDocument() }) }) describe('click handling', () => { - it('opens banner link in new tab when clicked', () => { + it('opens banner link in new tab and tracks click when clicked', () => { const banner = createMockBanner({ link: 'https://test-link.com' }) - render( - , - ) + renderBannerItem(banner, { sort: 2, language: 'zh-Hans', accountId: 'account-123' }) const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') fireEvent.click(bannerElement!) + expect(mockTrackEvent).toHaveBeenCalledWith('explore_banner_click', expect.objectContaining({ + banner_id: 'banner-1', + title: 'Test Banner Title', + sort: 2, + link: 'https://test-link.com', + page: 'explore', + language: 'zh-Hans', + account_id: 'account-123', + event_time: expect.any(Number), + })) expect(mockWindowOpen).toHaveBeenCalledWith( 'https://test-link.com', '_blank', @@ -155,18 +146,16 @@ describe('BannerItem', () => { ) }) - it('does not open window when banner has no link', () => { + it('tracks click even when banner has no link', () => { const banner = createMockBanner({ link: '' }) - render( - , - ) + renderBannerItem(banner) const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') fireEvent.click(bannerElement!) + expect(mockTrackEvent).toHaveBeenCalledWith('explore_banner_click', expect.objectContaining({ + link: '', + })) expect(mockWindowOpen).not.toHaveBeenCalled() }) }) @@ -174,28 +163,13 @@ describe('BannerItem', () => { describe('slide indicators', () => { it('renders correct number of indicator buttons', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(3) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(3) }) it('renders indicator buttons with correct numbers', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('01')).toBeInTheDocument() expect(screen.getByText('02')).toBeInTheDocument() expect(screen.getByText('03')).toBeInTheDocument() @@ -203,13 +177,7 @@ describe('BannerItem', () => { it('calls scrollTo when indicator is clicked', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) + renderBannerItem() const secondIndicator = screen.getByText('02').closest('button') fireEvent.click(secondIndicator!) @@ -219,81 +187,39 @@ describe('BannerItem', () => { it('renders no indicators when no slides', () => { mockSlideNodes.mockReturnValue([]) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.queryByRole('button')).not.toBeInTheDocument() }) }) describe('isPaused prop', () => { it('defaults isPaused to false', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) it('accepts isPaused prop', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem(createMockBanner(), { isPaused: true }) expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) }) describe('responsive behavior', () => { it('sets up ResizeObserver on mount', () => { - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(mockResizeObserverObserve).toHaveBeenCalled() }) it('adds resize event listener on mount', () => { const addEventListenerSpy = vi.spyOn(window, 'addEventListener') - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(addEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) addEventListenerSpy.mockRestore() }) it('removes resize event listener on unmount', () => { const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') - const banner = createMockBanner() - const { unmount } = render( - , - ) + const { unmount } = renderBannerItem() unmount() @@ -308,14 +234,7 @@ describe('BannerItem', () => { value: 1000, }) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('Test Banner Title')).toBeInTheDocument() }) @@ -326,14 +245,7 @@ describe('BannerItem', () => { value: 800, }) - const banner = createMockBanner() - render( - , - ) - + renderBannerItem() expect(screen.getByText('explore.banner.viewMore')).toBeInTheDocument() }) }) @@ -348,13 +260,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) expect(screen.getByText('Very Long Category Name')).toBeInTheDocument() }) @@ -367,13 +274,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) const titleElement = screen.getByText('A Very Long Title That Should Be Truncated Eventually') expect(titleElement).toHaveClass('line-clamp-2') }) @@ -387,13 +289,8 @@ describe('BannerItem', () => { 'img-src': 'https://example.com/img.png', }, } as Partial) - render( - , - ) + renderBannerItem(banner) const descriptionElement = screen.getByText(/A very long description/) expect(descriptionElement).toHaveClass('line-clamp-4') }) @@ -402,56 +299,26 @@ describe('BannerItem', () => { describe('slide calculation', () => { it('calculates next index correctly for first slide', () => { mockSlideNodes.mockReturnValue([{}, {}, {}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(3) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(3) }) it('handles single slide case', () => { mockSlideNodes.mockReturnValue([{}]) - const banner = createMockBanner() - render( - , - ) - - const buttons = screen.getAllByRole('button') - expect(buttons).toHaveLength(1) + renderBannerItem() + expect(screen.getAllByRole('button')).toHaveLength(1) }) }) describe('wrapper styling', () => { it('has cursor-pointer class', () => { - const banner = createMockBanner() - const { container } = render( - , - ) - + const { container } = renderBannerItem() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('cursor-pointer') }) it('has rounded-2xl class', () => { - const banner = createMockBanner() - const { container } = render( - , - ) - + const { container } = renderBannerItem() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('rounded-2xl') }) diff --git a/web/app/components/explore/banner/__tests__/banner.spec.tsx b/web/app/components/explore/banner/__tests__/banner.spec.tsx index d6d0aa44a8..069aaf02dc 100644 --- a/web/app/components/explore/banner/__tests__/banner.spec.tsx +++ b/web/app/components/explore/banner/__tests__/banner.spec.tsx @@ -6,6 +6,8 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import Banner from '../banner' const mockUseGetBanners = vi.fn() +const mockUseSelector = vi.fn() +const mockTrackEvent = vi.fn() vi.mock('@/service/use-explore', () => ({ useGetBanners: (...args: unknown[]) => mockUseGetBanners(...args), @@ -15,6 +17,14 @@ vi.mock('@/context/i18n', () => ({ useLocale: () => 'en-US', })) +vi.mock('@/context/app-context', () => ({ + useSelector: (...args: unknown[]) => mockUseSelector(...args), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + vi.mock('@/app/components/base/carousel', () => ({ Carousel: Object.assign( ({ children, onMouseEnter, onMouseLeave, className }: { @@ -54,9 +64,12 @@ vi.mock('@/app/components/base/carousel', () => ({ })) vi.mock('../banner-item', () => ({ - BannerItem: ({ banner, autoplayDelay, isPaused }: { + BannerItem: ({ banner, autoplayDelay, isPaused, sort, language, accountId }: { banner: BannerType autoplayDelay: number + sort: number + language: string + accountId?: string isPaused?: boolean }) => (
({ data-banner-id={banner.id} data-autoplay-delay={autoplayDelay} data-is-paused={isPaused} + data-sort={sort} + data-language={language} + data-account-id={accountId} > BannerItem: {' '} @@ -87,6 +103,11 @@ const createMockBanner = (id: string, status: string = 'enabled', title: string describe('Banner', () => { beforeEach(() => { vi.useFakeTimers() + mockUseSelector.mockImplementation(selector => selector({ + userProfile: { + id: 'account-123', + }, + })) }) afterEach(() => { @@ -235,6 +256,59 @@ describe('Banner', () => { expect(screen.getByTestId('carousel')).toHaveClass('rounded-2xl') }) + + it('tracks enabled banner impressions with expected payload', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Enabled Banner 1'), + createMockBanner('2', 'disabled', 'Disabled Banner'), + createMockBanner('3', 'enabled', 'Enabled Banner 2'), + ], + isLoading: false, + isError: false, + }) + + render() + + expect(mockTrackEvent).toHaveBeenCalledTimes(2) + expect(mockTrackEvent).toHaveBeenNthCalledWith(1, 'explore_banner_impression', expect.objectContaining({ + banner_id: '1', + title: 'Enabled Banner 1', + sort: 1, + link: 'https://example.com', + page: 'explore', + language: 'en-US', + account_id: 'account-123', + event_time: expect.any(Number), + })) + expect(mockTrackEvent).toHaveBeenNthCalledWith(2, 'explore_banner_impression', expect.objectContaining({ + banner_id: '3', + title: 'Enabled Banner 2', + sort: 2, + link: 'https://example.com', + page: 'explore', + language: 'en-US', + account_id: 'account-123', + event_time: expect.any(Number), + })) + }) + + it('does not track impressions when account id is unavailable', () => { + mockUseSelector.mockImplementation(selector => selector({ + userProfile: { + id: '', + }, + })) + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled', 'Enabled Banner 1')], + isLoading: false, + isError: false, + }) + + render() + + expect(mockTrackEvent).not.toHaveBeenCalled() + }) }) describe('hover behavior', () => { @@ -435,8 +509,25 @@ describe('Banner', () => { const bannerItems = screen.getAllByTestId('banner-item') expect(bannerItems[0]).toHaveAttribute('data-banner-id', '1') + expect(bannerItems[0]).toHaveAttribute('data-sort', '1') expect(bannerItems[1]).toHaveAttribute('data-banner-id', '2') + expect(bannerItems[1]).toHaveAttribute('data-sort', '2') expect(bannerItems[2]).toHaveAttribute('data-banner-id', '3') + expect(bannerItems[2]).toHaveAttribute('data-sort', '3') + }) + + it('passes tracking context to banner item', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled', 'Banner 1')], + isLoading: false, + isError: false, + }) + + render() + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-language', 'en-US') + expect(bannerItem).toHaveAttribute('data-account-id', 'account-123') }) }) diff --git a/web/app/components/explore/banner/banner-item.tsx b/web/app/components/explore/banner/banner-item.tsx index d90a1060f9..c1e48bf420 100644 --- a/web/app/components/explore/banner/banner-item.tsx +++ b/web/app/components/explore/banner/banner-item.tsx @@ -4,6 +4,7 @@ import type { Banner } from '@/models/app' import { RiArrowRightLine } from '@remixicon/react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { trackEvent } from '@/app/components/base/amplitude' import { useCarousel } from '@/app/components/base/carousel' import { cn } from '@/utils/classnames' import { IndicatorButton } from './indicator-button' @@ -11,6 +12,9 @@ import { IndicatorButton } from './indicator-button' type BannerItemProps = { banner: Banner autoplayDelay: number + sort: number + language: string + accountId?: string isPaused?: boolean } @@ -20,7 +24,14 @@ const INDICATOR_WIDTH = 20 const INDICATOR_GAP = 8 const MIN_VIEW_MORE_WIDTH = 480 -export const BannerItem: FC = ({ banner, autoplayDelay, isPaused = false }) => { +export const BannerItem: FC = ({ + banner, + autoplayDelay, + sort, + language, + accountId, + isPaused = false, +}) => { const { t } = useTranslation() const { api, selectedIndex } = useCarousel() const { category, title, description, 'img-src': imgSrc } = banner.content @@ -91,9 +102,21 @@ export const BannerItem: FC = ({ banner, autoplayDelay, isPause const handleBannerClick = useCallback(() => { incrementResetKey() + + trackEvent('explore_banner_click', { + banner_id: banner.id, + title: banner.content.title, + sort, + link: banner.link, + page: 'explore', + language, + account_id: accountId, + event_time: Date.now(), + }) + if (banner.link) window.open(banner.link, '_blank', 'noopener,noreferrer') - }, [banner.link, incrementResetKey]) + }, [accountId, banner, incrementResetKey, language, sort]) const handleIndicatorClick = useCallback((index: number) => { incrementResetKey() diff --git a/web/app/components/explore/banner/banner.tsx b/web/app/components/explore/banner/banner.tsx index 4ec0efdb2b..a320bb1974 100644 --- a/web/app/components/explore/banner/banner.tsx +++ b/web/app/components/explore/banner/banner.tsx @@ -1,7 +1,9 @@ import type { FC } from 'react' import * as React from 'react' import { useEffect, useMemo, useRef, useState } from 'react' +import { trackEvent } from '@/app/components/base/amplitude' import { Carousel } from '@/app/components/base/carousel' +import { useSelector } from '@/context/app-context' import { useLocale } from '@/context/i18n' import { useGetBanners } from '@/service/use-explore' import Loading from '../../base/loading' @@ -23,9 +25,11 @@ const LoadingState: FC = () => ( const Banner: FC = () => { const locale = useLocale() const { data: banners, isLoading, isError } = useGetBanners(locale) + const accountId = useSelector(s => s.userProfile.id) const [isHovered, setIsHovered] = useState(false) const [isResizing, setIsResizing] = useState(false) const resizeTimerRef = useRef(null) + const trackedBannerIdsRef = useRef>(new Set()) const enabledBanners = useMemo( () => banners?.filter(banner => banner.status === 'enabled') ?? [], @@ -56,6 +60,28 @@ const Banner: FC = () => { } }, []) + useEffect(() => { + if (!accountId) + return + + enabledBanners.forEach((banner, index) => { + if (trackedBannerIdsRef.current.has(banner.id)) + return + + trackEvent('explore_banner_impression', { + banner_id: banner.id, + title: banner.content.title, + sort: index + 1, + link: banner.link, + page: 'explore', + language: locale, + account_id: accountId, + event_time: Date.now(), + }) + trackedBannerIdsRef.current.add(banner.id) + }) + }, [accountId, enabledBanners, locale]) + if (isLoading) return @@ -77,12 +103,15 @@ const Banner: FC = () => { onMouseLeave={() => setIsHovered(false)} > - {enabledBanners.map(banner => ( + {enabledBanners.map((banner, index) => ( ))} diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index eb4d543e66..9d4226c33a 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -69,6 +69,7 @@ vi.mock('@/context/i18n', () => ({ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, + AMPLITUDE_API_KEY: '', ZENDESK_WIDGET_KEY: '', SUPPORT_EMAIL_ADDRESS: '', }, @@ -80,6 +81,8 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ })) vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, + get AMPLITUDE_API_KEY() { return mockConfig.AMPLITUDE_API_KEY }, + get isAmplitudeEnabled() { return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY }, get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, get SUPPORT_EMAIL_ADDRESS() { return mockConfig.SUPPORT_EMAIL_ADDRESS }, IS_DEV: false, diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index 2aa9db4771..279af0b114 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -315,14 +315,14 @@ describe('AccountSetting', () => { it('should handle scroll event in panel', () => { // Act renderAccountSetting() - const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto') + const scrollContainer = screen.getByRole('dialog').querySelector('.overscroll-contain') // Assert expect(scrollContainer).toBeInTheDocument() if (scrollContainer) { // Scroll down fireEvent.scroll(scrollContainer, { target: { scrollTop: 100 } }) - expect(scrollContainer).toHaveClass('overflow-y-auto') + expect(scrollContainer).toHaveClass('overscroll-contain') // Scroll back up fireEvent.scroll(scrollContainer, { target: { scrollTop: 0 } }) diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx index efe6c46dcc..5f1492f14a 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx @@ -78,6 +78,7 @@ const ApiBasedExtensionModal: FC = ({
diff --git a/web/app/components/header/account-setting/api-based-extension-page/selector.tsx b/web/app/components/header/account-setting/api-based-extension-page/selector.tsx index 38acb73154..62052aece6 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/selector.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/selector.tsx @@ -69,7 +69,7 @@ const ApiBasedExtensionSelector: FC = ({ ) } - +
diff --git a/web/app/components/header/account-setting/data-source-page-new/configure.tsx b/web/app/components/header/account-setting/data-source-page-new/configure.tsx index a3dba783e1..484338d333 100644 --- a/web/app/components/header/account-setting/data-source-page-new/configure.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/configure.tsx @@ -84,7 +84,7 @@ const Configure = ({ {t('dataSource.configure', { ns: 'common' })} - +
{ !!canOAuth && ( @@ -104,7 +104,7 @@ const Configure = ({ } { !!canApiKey && !!canOAuth && ( -
+
OR
diff --git a/web/app/components/header/account-setting/data-source-page-new/operator.tsx b/web/app/components/header/account-setting/data-source-page-new/operator.tsx index 14bdee4fd0..c5b2a948de 100644 --- a/web/app/components/header/account-setting/data-source-page-new/operator.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/operator.tsx @@ -39,7 +39,7 @@ const Operator = ({ text: (
-
{t('auth.setDefault', { ns: 'plugin' })}
+
{t('auth.setDefault', { ns: 'plugin' })}
), }, @@ -51,7 +51,7 @@ const Operator = ({ text: (
-
{t('operation.rename', { ns: 'common' })}
+
{t('operation.rename', { ns: 'common' })}
), }, @@ -66,7 +66,7 @@ const Operator = ({ text: (
-
{t('operation.edit', { ns: 'common' })}
+
{t('operation.edit', { ns: 'common' })}
), }, @@ -81,7 +81,7 @@ const Operator = ({ text: (
-
{t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
+
{t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
), }, @@ -98,7 +98,7 @@ const Operator = ({ text: (
-
+
{t('operation.remove', { ns: 'common' })}
@@ -122,7 +122,7 @@ const Operator = ({ items={items} secondItems={secondItems} onSelect={handleSelect} - popupClassName="z-[61]" + popupClassName="z-[1002]" triggerProps={{ size: 'l', }} diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx deleted file mode 100644 index dad82d81b9..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/__tests__/index.spec.tsx +++ /dev/null @@ -1,462 +0,0 @@ -import type { UseQueryResult } from '@tanstack/react-query' -import type { AppContextValue } from '@/context/app-context' -import type { DataSourceNotion as TDataSourceNotion } from '@/models/common' -import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' -import { useAppContext } from '@/context/app-context' -import { useDataSourceIntegrates, useInvalidDataSourceIntegrates, useNotionConnection } from '@/service/use-common' -import DataSourceNotion from '../index' - -/** - * DataSourceNotion Component Tests - * Using Unit approach with real Panel and sibling components to test Notion integration logic. - */ - -type MockQueryResult = UseQueryResult - -// Mock dependencies -vi.mock('@/context/app-context', () => ({ - useAppContext: vi.fn(), -})) - -vi.mock('@/service/common', () => ({ - syncDataSourceNotion: vi.fn(), - updateDataSourceNotionAction: vi.fn(), -})) - -vi.mock('@/service/use-common', () => ({ - useDataSourceIntegrates: vi.fn(), - useNotionConnection: vi.fn(), - useInvalidDataSourceIntegrates: vi.fn(), -})) - -describe('DataSourceNotion Component', () => { - const mockWorkspaces: TDataSourceNotion[] = [ - { - id: 'ws-1', - provider: 'notion', - is_bound: true, - source_info: { - workspace_name: 'Workspace 1', - workspace_icon: 'https://example.com/icon-1.png', - workspace_id: 'notion-ws-1', - total: 10, - pages: [], - }, - }, - ] - - const baseAppContext: AppContextValue = { - userProfile: { id: 'test-user-id', name: 'test-user', email: 'test@example.com', avatar: '', avatar_url: '', is_password_set: true }, - mutateUserProfile: vi.fn(), - currentWorkspace: { id: 'ws-id', name: 'Workspace', plan: 'basic', status: 'normal', created_at: 0, role: 'owner', providers: [], trial_credits: 0, trial_credits_used: 0, next_credit_reset_date: 0 }, - isCurrentWorkspaceManager: true, - isCurrentWorkspaceOwner: true, - isCurrentWorkspaceEditor: true, - isCurrentWorkspaceDatasetOperator: false, - mutateCurrentWorkspace: vi.fn(), - langGeniusVersionInfo: { current_version: '0.1.0', latest_version: '0.1.1', version: '0.1.1', release_date: '', release_notes: '', can_auto_update: false, current_env: 'test' }, - useSelector: vi.fn(), - isLoadingCurrentWorkspace: false, - isValidatingCurrentWorkspace: false, - } - - /* eslint-disable-next-line ts/no-explicit-any */ - const mockQuerySuccess = (data: T): MockQueryResult => ({ data, isSuccess: true, isError: false, isLoading: false, isPending: false, status: 'success', error: null, fetchStatus: 'idle' } as any) - /* eslint-disable-next-line ts/no-explicit-any */ - const mockQueryPending = (): MockQueryResult => ({ data: undefined, isSuccess: false, isError: false, isLoading: true, isPending: true, status: 'pending', error: null, fetchStatus: 'fetching' } as any) - - const originalLocation = window.location - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useAppContext).mockReturnValue(baseAppContext) - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [] })) - vi.mocked(useNotionConnection).mockReturnValue(mockQueryPending()) - vi.mocked(useInvalidDataSourceIntegrates).mockReturnValue(vi.fn()) - - const locationMock = { href: '', assign: vi.fn() } - Object.defineProperty(window, 'location', { value: locationMock, writable: true, configurable: true }) - - // Clear document body to avoid toast leaks between tests - document.body.innerHTML = '' - }) - - afterEach(() => { - Object.defineProperty(window, 'location', { value: originalLocation, writable: true, configurable: true }) - }) - - const getWorkspaceItem = (name: string) => { - const nameEl = screen.getByText(name) - return (nameEl.closest('div[class*="workspace-item"]') || nameEl.parentElement) as HTMLElement - } - - describe('Rendering', () => { - it('should render with no workspaces initially and call integration hook', () => { - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.title')).toBeInTheDocument() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - }) - - it('should render with provided workspaces and pass initialData to hook', () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - expect(screen.getByText('Workspace 1')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.connected')).toBeInTheDocument() - expect(screen.getByAltText('workspace icon')).toHaveAttribute('src', 'https://example.com/icon-1.png') - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: mockWorkspaces } }) - }) - - it('should handle workspaces prop being an empty array', () => { - // Act - render() - - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: [] } }) - }) - - it('should handle optional workspaces configurations', () => { - // Branch: workspaces passed as undefined - const { rerender } = render() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - - // Branch: workspaces passed as null - /* eslint-disable-next-line ts/no-explicit-any */ - rerender() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: undefined }) - - // Branch: workspaces passed as [] - rerender() - expect(useDataSourceIntegrates).toHaveBeenCalledWith({ initialData: { data: [] } }) - }) - - it('should handle cases where integrates data is loading or broken', () => { - // Act (Loading) - const { rerender } = render() - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQueryPending()) - rerender() - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - - // Act (Broken) - const brokenData = {} as { data: TDataSourceNotion[] } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess(brokenData)) - rerender() - // Assert - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates being nullish', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: undefined, isSuccess: true } as any) - render() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates data being nullish', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: { data: null }, isSuccess: true } as any) - render() - expect(screen.queryByText('common.dataSource.notion.connectedWorkspace')).not.toBeInTheDocument() - }) - - it('should handle integrates data being valid', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: { data: [{ id: '1', is_bound: true, source_info: { workspace_name: 'W', workspace_icon: 'https://example.com/i.png', total: 1, pages: [] } }] }, isSuccess: true } as any) - render() - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - }) - - it('should cover all possible falsy/nullish branches for integrates and workspaces', () => { - /* eslint-disable-next-line ts/no-explicit-any */ - const { rerender } = render() - - const integratesCases = [ - undefined, - null, - {}, - { data: null }, - { data: undefined }, - { data: [] }, - { data: [mockWorkspaces[0]] }, - { data: false }, - { data: 0 }, - { data: '' }, - 123, - 'string', - false, - ] - - integratesCases.forEach((val) => { - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useDataSourceIntegrates).mockReturnValue({ data: val, isSuccess: true } as any) - /* eslint-disable-next-line ts/no-explicit-any */ - rerender() - }) - - expect(useDataSourceIntegrates).toHaveBeenCalled() - }) - }) - - describe('User Permissions', () => { - it('should pass readOnly as false when user is a manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: true }) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.title').closest('div')).not.toHaveClass('grayscale') - }) - - it('should pass readOnly as true when user is NOT a manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: false }) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.connect')).toHaveClass('opacity-50', 'grayscale') - }) - }) - - describe('Configure and Auth Actions', () => { - it('should handle configure action when user is workspace manager', () => { - // Arrange - render() - - // Act - fireEvent.click(screen.getByText('common.dataSource.connect')) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(true) - }) - - it('should block configure action when user is NOT workspace manager', () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ ...baseAppContext, isCurrentWorkspaceManager: false }) - render() - - // Act - fireEvent.click(screen.getByText('common.dataSource.connect')) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(false) - }) - - it('should redirect if auth URL is available when "Auth Again" is clicked', async () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http://auth-url' })) - render() - - // Act - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - // Assert - expect(window.location.href).toBe('http://auth-url') - }) - - it('should trigger connection flow if URL is missing when "Auth Again" is clicked', async () => { - // Arrange - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - render() - - // Act - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - // Assert - expect(useNotionConnection).toHaveBeenCalledWith(true) - }) - }) - - describe('Side Effects (Redirection and Toast)', () => { - it('should redirect automatically when connection data returns an http URL', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http://redirect-url' })) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('http://redirect-url') - }) - }) - - it('should show toast notification when connection data is "internal"', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'internal' })) - - // Act - render() - - // Assert - expect(await screen.findByText('common.dataSource.notion.integratedAlert')).toBeInTheDocument() - }) - - it('should handle various data types and missing properties in connection data correctly', async () => { - // Arrange & Act (Unknown string) - const { rerender } = render() - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'unknown' })) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - expect(screen.queryByText('common.dataSource.notion.integratedAlert')).not.toBeInTheDocument() - }) - - // Act (Broken object) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({} as any)) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - - // Act (Non-string) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 123 } as any)) - rerender() - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - - it('should redirect if data starts with "http" even if it is just "http"', async () => { - // Arrange - vi.mocked(useNotionConnection).mockReturnValue(mockQuerySuccess({ data: 'http' })) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('http') - }) - }) - - it('should skip side effect logic if connection data is an object but missing the "data" property', async () => { - // Arrange - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({} as any) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - - it('should skip side effect logic if data.data is falsy', async () => { - // Arrange - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({ data: { data: null } } as any) - - // Act - render() - - // Assert - await waitFor(() => { - expect(window.location.href).toBe('') - }) - }) - }) - - describe('Additional Action Edge Cases', () => { - it.each([ - undefined, - null, - {}, - { data: undefined }, - { data: null }, - { data: '' }, - { data: 0 }, - { data: false }, - { data: 'http' }, - { data: 'internal' }, - { data: 'unknown' }, - ])('should cover connection data branch: %s', async (val) => { - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces })) - /* eslint-disable-next-line ts/no-explicit-any */ - vi.mocked(useNotionConnection).mockReturnValue({ data: val, isSuccess: true } as any) - - render() - - // Trigger handleAuthAgain with these values - const workspaceItem = getWorkspaceItem('Workspace 1') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - fireEvent.click(authAgainBtn) - - expect(useNotionConnection).toHaveBeenCalled() - }) - }) - - describe('Edge Cases in Workspace Data', () => { - it('should render correctly with missing source_info optional fields', async () => { - // Arrange - const workspaceWithMissingInfo: TDataSourceNotion = { - id: 'ws-2', - provider: 'notion', - is_bound: false, - source_info: { workspace_name: 'Workspace 2', workspace_id: 'notion-ws-2', workspace_icon: null, pages: [] }, - } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [workspaceWithMissingInfo] })) - - // Act - render() - - // Assert - expect(screen.getByText('Workspace 2')).toBeInTheDocument() - - const workspaceItem = getWorkspaceItem('Workspace 2') - const actionBtn = within(workspaceItem).getByRole('button') - fireEvent.click(actionBtn) - - expect(await screen.findByText('0 common.dataSource.notion.pagesAuthorized')).toBeInTheDocument() - }) - - it('should display inactive status correctly for unbound workspaces', () => { - // Arrange - const inactiveWS: TDataSourceNotion = { - id: 'ws-3', - provider: 'notion', - is_bound: false, - source_info: { workspace_name: 'Workspace 3', workspace_icon: 'https://example.com/icon-3.png', workspace_id: 'notion-ws-3', total: 5, pages: [] }, - } - vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: [inactiveWS] })) - - // Act - render() - - // Assert - expect(screen.getByText('common.dataSource.notion.disconnected')).toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx deleted file mode 100644 index 0959383f29..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx +++ /dev/null @@ -1,103 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { DataSourceNotion as TDataSourceNotion } from '@/models/common' -import { noop } from 'es-toolkit/function' -import * as React from 'react' -import { useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import NotionIcon from '@/app/components/base/notion-icon' -import Toast from '@/app/components/base/toast' -import { useAppContext } from '@/context/app-context' -import { useDataSourceIntegrates, useNotionConnection } from '@/service/use-common' -import Panel from '../panel' -import { DataSourceType } from '../panel/types' - -const Icon: FC<{ - src: string - name: string - className: string -}> = ({ src, name, className }) => { - return ( - - ) -} -type Props = { - workspaces?: TDataSourceNotion[] -} - -const DataSourceNotion: FC = ({ - workspaces, -}) => { - const { isCurrentWorkspaceManager } = useAppContext() - const [canConnectNotion, setCanConnectNotion] = useState(false) - const { data: integrates } = useDataSourceIntegrates({ - initialData: workspaces ? { data: workspaces } : undefined, - }) - const { data } = useNotionConnection(canConnectNotion) - const { t } = useTranslation() - - const resolvedWorkspaces = integrates?.data ?? [] - const connected = !!resolvedWorkspaces.length - - const handleConnectNotion = () => { - if (!isCurrentWorkspaceManager) - return - - setCanConnectNotion(true) - } - - const handleAuthAgain = () => { - if (data?.data) - window.location.href = data.data - else - setCanConnectNotion(true) - } - - useEffect(() => { - if (data && 'data' in data) { - if (data.data && typeof data.data === 'string' && data.data.startsWith('http')) { - window.location.href = data.data - } - else if (data.data === 'internal') { - Toast.notify({ - type: 'info', - message: t('dataSource.notion.integratedAlert', { ns: 'common' }), - }) - } - } - }, [data, t]) - - return ( - ({ - id: workspace.id, - logo: ({ className }: { className: string }) => ( - - ), - name: workspace.source_info.workspace_name, - isActive: workspace.is_bound, - notionConfig: { - total: workspace.source_info.total || 0, - }, - }))} - onRemove={noop} // handled in operation/index.tsx - notionActions={{ - onChangeAuthorizedPage: handleAuthAgain, - }} - /> - ) -} -export default React.memo(DataSourceNotion) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx deleted file mode 100644 index f433b10020..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/__tests__/index.spec.tsx +++ /dev/null @@ -1,137 +0,0 @@ -import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' -import { syncDataSourceNotion, updateDataSourceNotionAction } from '@/service/common' -import { useInvalidDataSourceIntegrates } from '@/service/use-common' -import Operate from '../index' - -/** - * Operate Component (Notion) Tests - * This component provides actions like Sync, Change Pages, and Remove for Notion data sources. - */ - -// Mock services and toast -vi.mock('@/service/common', () => ({ - syncDataSourceNotion: vi.fn(), - updateDataSourceNotionAction: vi.fn(), -})) - -vi.mock('@/service/use-common', () => ({ - useInvalidDataSourceIntegrates: vi.fn(), -})) - -describe('Operate Component (Notion)', () => { - const mockPayload = { - id: 'test-notion-id', - total: 5, - } - const mockOnAuthAgain = vi.fn() - const mockInvalidate = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useInvalidDataSourceIntegrates).mockReturnValue(mockInvalidate) - vi.mocked(syncDataSourceNotion).mockResolvedValue({ result: 'success' }) - vi.mocked(updateDataSourceNotionAction).mockResolvedValue({ result: 'success' }) - }) - - describe('Rendering', () => { - it('should render the menu button initially', () => { - // Act - const { container } = render() - - // Assert - const menuButton = within(container).getByRole('button') - expect(menuButton).toBeInTheDocument() - expect(menuButton).not.toHaveClass('bg-state-base-hover') - }) - - it('should open the menu and show all options when clicked', async () => { - // Arrange - const { container } = render() - const menuButton = within(container).getByRole('button') - - // Act - fireEvent.click(menuButton) - - // Assert - expect(await screen.findByText('common.dataSource.notion.changeAuthorizedPages')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.sync')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.remove')).toBeInTheDocument() - expect(screen.getByText(/5/)).toBeInTheDocument() - expect(screen.getByText(/common.dataSource.notion.pagesAuthorized/)).toBeInTheDocument() - expect(menuButton).toHaveClass('bg-state-base-hover') - }) - }) - - describe('Menu Actions', () => { - it('should call onAuthAgain when Change Authorized Pages is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const option = await screen.findByText('common.dataSource.notion.changeAuthorizedPages') - - // Act - fireEvent.click(option) - - // Assert - expect(mockOnAuthAgain).toHaveBeenCalledTimes(1) - }) - - it('should call handleSync, show success toast, and invalidate cache when Sync is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const syncBtn = await screen.findByText('common.dataSource.notion.sync') - - // Act - fireEvent.click(syncBtn) - - // Assert - await waitFor(() => { - expect(syncDataSourceNotion).toHaveBeenCalledWith({ - url: `/oauth/data-source/notion/${mockPayload.id}/sync`, - }) - }) - expect((await screen.findAllByText('common.api.success')).length).toBeGreaterThan(0) - expect(mockInvalidate).toHaveBeenCalledTimes(1) - }) - - it('should call handleRemove, show success toast, and invalidate cache when Remove is clicked', async () => { - // Arrange - const { container } = render() - fireEvent.click(within(container).getByRole('button')) - const removeBtn = await screen.findByText('common.dataSource.notion.remove') - - // Act - fireEvent.click(removeBtn) - - // Assert - await waitFor(() => { - expect(updateDataSourceNotionAction).toHaveBeenCalledWith({ - url: `/data-source/integrates/${mockPayload.id}/disable`, - }) - }) - expect((await screen.findAllByText('common.api.success')).length).toBeGreaterThan(0) - expect(mockInvalidate).toHaveBeenCalledTimes(1) - }) - }) - - describe('State Transitions', () => { - it('should toggle the open class on the button based on menu visibility', async () => { - // Arrange - const { container } = render() - const menuButton = within(container).getByRole('button') - - // Act (Open) - fireEvent.click(menuButton) - // Assert - expect(menuButton).toHaveClass('bg-state-base-hover') - - // Act (Close - click again) - fireEvent.click(menuButton) - // Assert - await waitFor(() => { - expect(menuButton).not.toHaveClass('bg-state-base-hover') - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx deleted file mode 100644 index 043eb3c846..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx +++ /dev/null @@ -1,103 +0,0 @@ -'use client' -import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' -import { - RiDeleteBinLine, - RiLoopLeftLine, - RiMoreFill, - RiStickyNoteAddLine, -} from '@remixicon/react' -import { Fragment } from 'react' -import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' -import { syncDataSourceNotion, updateDataSourceNotionAction } from '@/service/common' -import { useInvalidDataSourceIntegrates } from '@/service/use-common' -import { cn } from '@/utils/classnames' - -type OperateProps = { - payload: { - id: string - total: number - } - onAuthAgain: () => void -} -export default function Operate({ - payload, - onAuthAgain, -}: OperateProps) { - const { t } = useTranslation() - const invalidateDataSourceIntegrates = useInvalidDataSourceIntegrates() - - const updateIntegrates = () => { - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - invalidateDataSourceIntegrates() - } - const handleSync = async () => { - await syncDataSourceNotion({ url: `/oauth/data-source/notion/${payload.id}/sync` }) - updateIntegrates() - } - const handleRemove = async () => { - await updateDataSourceNotionAction({ url: `/data-source/integrates/${payload.id}/disable` }) - updateIntegrates() - } - - return ( - - { - ({ open }) => ( - <> - - - - - -
- -
- -
-
{t('dataSource.notion.changeAuthorizedPages', { ns: 'common' })}
-
- {payload.total} - {' '} - {t('dataSource.notion.pagesAuthorized', { ns: 'common' })} -
-
-
-
- -
- -
{t('dataSource.notion.sync', { ns: 'common' })}
-
-
-
- -
-
- -
{t('dataSource.notion.remove', { ns: 'common' })}
-
-
-
-
-
- - ) - } -
- ) -} diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx deleted file mode 100644 index dadda4a349..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-firecrawl-modal.spec.tsx +++ /dev/null @@ -1,204 +0,0 @@ -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigFirecrawlModal from '../config-firecrawl-modal' - -/** - * ConfigFirecrawlModal Component Tests - * Tests validation, save logic, and basic rendering for the Firecrawl configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigFirecrawlModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with all fields and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByPlaceholderText('https://api.firecrawl.dev')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.firecrawl\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://www.firecrawl.dev/account') - }) - }) - - describe('Form Interactions', () => { - it('should update state when input fields change', async () => { - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder') - const baseUrlInput = screen.getByPlaceholderText('https://api.firecrawl.dev') - - // Act - fireEvent.change(apiKeyInput, { target: { value: 'firecrawl-key' } }) - fireEvent.change(baseUrlInput, { target: { value: 'https://custom.firecrawl.dev' } }) - - // Assert - expect(apiKeyInput).toHaveValue('firecrawl-key') - expect(baseUrlInput).toHaveValue('https://custom.firecrawl.dev') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - - it('should show error for invalid Base URL format', async () => { - const user = userEvent.setup() - // Arrange - render() - const baseUrlInput = screen.getByPlaceholderText('https://api.firecrawl.dev') - - // Act - await user.type(baseUrlInput, 'ftp://invalid-url.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.urlError')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key and custom URL', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'valid-key') - await user.type(screen.getByPlaceholderText('https://api.firecrawl.dev'), 'http://my-firecrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: 'firecrawl', - credentials: { - auth_type: 'bearer', - config: { - api_key: 'valid-key', - base_url: 'http://my-firecrawl.com', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should use default Base URL if none is provided during save', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://api.firecrawl.dev', - }), - }), - })) - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: CommonResponse) => void - const savePromise = new Promise((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should accept base_url starting with https://', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder'), 'test-key') - await user.type(screen.getByPlaceholderText('https://api.firecrawl.dev'), 'https://secure-firecrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://secure-firecrawl.com', - }), - }), - })) - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx deleted file mode 100644 index 26c53993c1..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-jina-reader-modal.spec.tsx +++ /dev/null @@ -1,179 +0,0 @@ -import { render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { DataSourceProvider } from '@/models/common' -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigJinaReaderModal from '../config-jina-reader-modal' - -/** - * ConfigJinaReaderModal Component Tests - * Tests validation, save logic, and basic rendering for the Jina Reader configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigJinaReaderModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with API Key field and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.jinaReader.configJinaReader')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.jinaReader\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://jina.ai/reader/') - }) - }) - - describe('Form Interactions', () => { - it('should update state when API Key field changes', async () => { - const user = userEvent.setup() - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - - // Act - await user.type(apiKeyInput, 'jina-test-key') - - // Assert - expect(apiKeyInput).toHaveValue('jina-test-key') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - - // Act - await user.type(apiKeyInput, 'valid-jina-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: DataSourceProvider.jinaReader, - credentials: { - auth_type: 'bearer', - config: { - api_key: 'valid-jina-key', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: { result: 'success' }) => void - const savePromise = new Promise<{ result: 'success' }>((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder'), 'test-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should show encryption info and external link in the modal', async () => { - render() - - // Verify PKCS1_OAEP link exists - const pkcsLink = screen.getByText('PKCS1_OAEP') - expect(pkcsLink.closest('a')).toHaveAttribute('href', 'https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html') - - // Verify the Jina Reader external link - const jinaLink = screen.getByRole('link', { name: /datasetCreation\.jinaReader\.getApiKeyLinkText/i }) - expect(jinaLink).toHaveAttribute('target', '_blank') - }) - - it('should return early when save is clicked while already saving (isSaving guard)', async () => { - const user = userEvent.setup() - // Arrange - a save that never resolves so isSaving stays true - let resolveFirst: (value: { result: 'success' }) => void - const neverResolves = new Promise<{ result: 'success' }>((resolve) => { - resolveFirst = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(neverResolves) - render() - - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder') - await user.type(apiKeyInput, 'valid-key') - - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - // First click - starts saving, isSaving becomes true - await user.click(saveBtn) - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Second click using fireEvent bypasses disabled check - hits isSaving guard - const { fireEvent: fe } = await import('@testing-library/react') - fe.click(saveBtn) - // Still only called once because isSaving=true returns early - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveFirst!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalled()) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx deleted file mode 100644 index 6c5961be54..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/config-watercrawl-modal.spec.tsx +++ /dev/null @@ -1,204 +0,0 @@ -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { createDataSourceApiKeyBinding } from '@/service/datasets' -import ConfigWatercrawlModal from '../config-watercrawl-modal' - -/** - * ConfigWatercrawlModal Component Tests - * Tests validation, save logic, and basic rendering for the Watercrawl configuration modal. - */ - -vi.mock('@/service/datasets', () => ({ - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('ConfigWatercrawlModal Component', () => { - const mockOnCancel = vi.fn() - const mockOnSaved = vi.fn() - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Initial Rendering', () => { - it('should render the modal with all fields and buttons', () => { - // Act - render() - - // Assert - expect(screen.getByText('datasetCreation.watercrawl.configWatercrawl')).toBeInTheDocument() - expect(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder')).toBeInTheDocument() - expect(screen.getByPlaceholderText('https://app.watercrawl.dev')).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.save/i })).toBeInTheDocument() - expect(screen.getByRole('button', { name: /common\.operation\.cancel/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /datasetCreation\.watercrawl\.getApiKeyLinkText/i })).toHaveAttribute('href', 'https://app.watercrawl.dev/') - }) - }) - - describe('Form Interactions', () => { - it('should update state when input fields change', async () => { - // Arrange - render() - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder') - const baseUrlInput = screen.getByPlaceholderText('https://app.watercrawl.dev') - - // Act - fireEvent.change(apiKeyInput, { target: { value: 'water-key' } }) - fireEvent.change(baseUrlInput, { target: { value: 'https://custom.watercrawl.dev' } }) - - // Assert - expect(apiKeyInput).toHaveValue('water-key') - expect(baseUrlInput).toHaveValue('https://custom.watercrawl.dev') - }) - - it('should call onCancel when cancel button is clicked', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - expect(mockOnCancel).toHaveBeenCalled() - }) - }) - - describe('Validation', () => { - it('should show error when saving without API Key', async () => { - const user = userEvent.setup() - // Arrange - render() - - // Act - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.fieldRequired:{"field":"API Key"}')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - - it('should show error for invalid Base URL format', async () => { - const user = userEvent.setup() - // Arrange - render() - const baseUrlInput = screen.getByPlaceholderText('https://app.watercrawl.dev') - - // Act - await user.type(baseUrlInput, 'ftp://invalid-url.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(screen.getByText('common.errorMsg.urlError')).toBeInTheDocument() - }) - expect(createDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Saving Logic', () => { - it('should save successfully with valid API Key and custom URL', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'valid-key') - await user.type(screen.getByPlaceholderText('https://app.watercrawl.dev'), 'http://my-watercrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith({ - category: 'website', - provider: 'watercrawl', - credentials: { - auth_type: 'x-api-key', - config: { - api_key: 'valid-key', - base_url: 'http://my-watercrawl.com', - }, - }, - }) - }) - await waitFor(() => { - expect(screen.getByText('common.api.success')).toBeInTheDocument() - expect(mockOnSaved).toHaveBeenCalled() - }) - }) - - it('should use default Base URL if none is provided during save', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://app.watercrawl.dev', - }), - }), - })) - }) - }) - - it('should ignore multiple save clicks while saving is in progress', async () => { - const user = userEvent.setup() - // Arrange - let resolveSave: (value: CommonResponse) => void - const savePromise = new Promise((resolve) => { - resolveSave = resolve - }) - vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(savePromise) - render() - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i }) - - // Act - await user.click(saveBtn) - await user.click(saveBtn) - - // Assert - expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1) - - // Cleanup - resolveSave!({ result: 'success' }) - await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1)) - }) - - it('should accept base_url starting with https://', async () => { - const user = userEvent.setup() - // Arrange - vi.mocked(createDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' }) - render() - - // Act - await user.type(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), 'test-api-key') - await user.type(screen.getByPlaceholderText('https://app.watercrawl.dev'), 'https://secure-watercrawl.com') - await user.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(createDataSourceApiKeyBinding).toHaveBeenCalledWith(expect.objectContaining({ - credentials: expect.objectContaining({ - config: expect.objectContaining({ - base_url: 'https://secure-watercrawl.com', - }), - }), - })) - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx deleted file mode 100644 index 1e95cbd087..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/__tests__/index.spec.tsx +++ /dev/null @@ -1,251 +0,0 @@ -import type { AppContextValue } from '@/context/app-context' -import type { CommonResponse } from '@/models/common' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' - -import { useAppContext } from '@/context/app-context' -import { DataSourceProvider } from '@/models/common' -import { fetchDataSources, removeDataSourceApiKeyBinding } from '@/service/datasets' -import DataSourceWebsite from '../index' - -/** - * DataSourceWebsite Component Tests - * Tests integration of multiple website scraping providers (Firecrawl, WaterCrawl, Jina Reader). - */ - -type DataSourcesResponse = CommonResponse & { - sources: Array<{ id: string, provider: DataSourceProvider }> -} - -// Mock App Context -vi.mock('@/context/app-context', () => ({ - useAppContext: vi.fn(), -})) - -// Mock Service calls -vi.mock('@/service/datasets', () => ({ - fetchDataSources: vi.fn(), - removeDataSourceApiKeyBinding: vi.fn(), - createDataSourceApiKeyBinding: vi.fn(), -})) - -describe('DataSourceWebsite Component', () => { - const mockSources = [ - { id: '1', provider: DataSourceProvider.fireCrawl }, - { id: '2', provider: DataSourceProvider.waterCrawl }, - { id: '3', provider: DataSourceProvider.jinaReader }, - ] - - beforeEach(() => { - vi.clearAllMocks() - vi.mocked(useAppContext).mockReturnValue({ isCurrentWorkspaceManager: true } as unknown as AppContextValue) - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [] } as DataSourcesResponse) - }) - - // Helper to render and wait for initial fetch to complete - const renderAndWait = async (provider: DataSourceProvider) => { - const result = render() - await waitFor(() => expect(fetchDataSources).toHaveBeenCalled()) - return result - } - - describe('Data Initialization', () => { - it('should fetch data sources on mount and reflect configured status', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: mockSources } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument() - }) - - it('should pass readOnly status based on workspace manager permissions', async () => { - // Arrange - vi.mocked(useAppContext).mockReturnValue({ isCurrentWorkspaceManager: false } as unknown as AppContextValue) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(screen.getByText('common.dataSource.configure')).toHaveClass('cursor-default') - }) - }) - - describe('Provider Specific Rendering', () => { - it('should render correct logo and name for Firecrawl', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[0]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.fireCrawl) - - // Assert - expect(await screen.findByText('Firecrawl')).toBeInTheDocument() - expect(screen.getByText('🔥')).toBeInTheDocument() - }) - - it('should render correct logo and name for WaterCrawl', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[1]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.waterCrawl) - - // Assert - const elements = await screen.findAllByText('WaterCrawl') - expect(elements.length).toBeGreaterThanOrEqual(1) - }) - - it('should render correct logo and name for Jina Reader', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[2]] } as DataSourcesResponse) - - // Act - await renderAndWait(DataSourceProvider.jinaReader) - - // Assert - const elements = await screen.findAllByText('Jina Reader') - expect(elements.length).toBeGreaterThanOrEqual(1) - }) - }) - - describe('Modal Interactions', () => { - it('should manage opening and closing of configuration modals', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - - // Act (Open) - fireEvent.click(screen.getByText('common.dataSource.configure')) - // Assert - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - - // Act (Cancel) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - // Assert - expect(screen.queryByText('datasetCreation.firecrawl.configFirecrawl')).not.toBeInTheDocument() - }) - - it('should re-fetch sources after saving configuration (Watercrawl)', async () => { - // Arrange - await renderAndWait(DataSourceProvider.waterCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - vi.mocked(fetchDataSources).mockClear() - - // Act - fireEvent.change(screen.getByPlaceholderText('datasetCreation.watercrawl.apiKeyPlaceholder'), { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.watercrawl.configWatercrawl')).not.toBeInTheDocument() - }) - }) - - it('should re-fetch sources after saving configuration (Jina Reader)', async () => { - // Arrange - await renderAndWait(DataSourceProvider.jinaReader) - fireEvent.click(screen.getByText('common.dataSource.configure')) - vi.mocked(fetchDataSources).mockClear() - - // Act - fireEvent.change(screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder'), { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.jinaReader.configJinaReader')).not.toBeInTheDocument() - }) - }) - }) - - describe('Management Actions', () => { - it('should handle successful data source removal with toast notification', async () => { - // Arrange - vi.mocked(fetchDataSources).mockResolvedValue({ result: 'success', sources: [mockSources[0]] } as DataSourcesResponse) - vi.mocked(removeDataSourceApiKeyBinding).mockResolvedValue({ result: 'success' } as CommonResponse) - await renderAndWait(DataSourceProvider.fireCrawl) - await waitFor(() => expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument()) - - // Act - const removeBtn = screen.getByText('Firecrawl').parentElement?.querySelector('svg')?.parentElement - if (removeBtn) - fireEvent.click(removeBtn) - - // Assert - await waitFor(() => { - expect(removeDataSourceApiKeyBinding).toHaveBeenCalledWith('1') - expect(screen.getByText('common.api.remove')).toBeInTheDocument() - }) - expect(screen.queryByText('common.dataSource.website.configuredCrawlers')).not.toBeInTheDocument() - }) - - it('should skip removal API call if no data source ID is present', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - - // Act - const removeBtn = screen.queryByText('Firecrawl')?.parentElement?.querySelector('svg')?.parentElement - if (removeBtn) - fireEvent.click(removeBtn) - - // Assert - expect(removeDataSourceApiKeyBinding).not.toHaveBeenCalled() - }) - }) - - describe('Firecrawl Save Flow', () => { - it('should re-fetch sources after saving Firecrawl configuration', async () => { - // Arrange - await renderAndWait(DataSourceProvider.fireCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument() - vi.mocked(fetchDataSources).mockClear() - - // Act - fill in required API key field and save - const apiKeyInput = screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder') - fireEvent.change(apiKeyInput, { target: { value: 'test-key' } }) - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i })) - - // Assert - await waitFor(() => { - expect(fetchDataSources).toHaveBeenCalled() - expect(screen.queryByText('datasetCreation.firecrawl.configFirecrawl')).not.toBeInTheDocument() - }) - }) - }) - - describe('Cancel Flow', () => { - it('should close watercrawl modal when cancel is clicked', async () => { - // Arrange - await renderAndWait(DataSourceProvider.waterCrawl) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.watercrawl.configWatercrawl')).toBeInTheDocument() - - // Act - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - modal closed - await waitFor(() => { - expect(screen.queryByText('datasetCreation.watercrawl.configWatercrawl')).not.toBeInTheDocument() - }) - }) - - it('should close jina reader modal when cancel is clicked', async () => { - // Arrange - await renderAndWait(DataSourceProvider.jinaReader) - fireEvent.click(screen.getByText('common.dataSource.configure')) - expect(screen.getByText('datasetCreation.jinaReader.configJinaReader')).toBeInTheDocument() - - // Act - fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i })) - - // Assert - modal closed - await waitFor(() => { - expect(screen.queryByText('datasetCreation.jinaReader.configJinaReader')).not.toBeInTheDocument() - }) - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx deleted file mode 100644 index d7f15236a7..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx +++ /dev/null @@ -1,165 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { FirecrawlConfig } from '@/models/common' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'firecrawl' - -const DEFAULT_BASE_URL = 'https://api.firecrawl.dev' - -const ConfigFirecrawlModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [config, setConfig] = useState({ - api_key: '', - base_url: '', - }) - - const handleConfigChange = useCallback((key: string) => { - return (value: string | number) => { - setConfig(prev => ({ ...prev, [key]: value as string })) - } - }, []) - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (config.base_url && !((config.base_url.startsWith('http://') || config.base_url.startsWith('https://')))) - errorMsg = t('errorMsg.urlError', { ns: 'common' }) - if (!errorMsg) { - if (!config.api_key) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: 'firecrawl', - credentials: { - auth_type: 'bearer', - config: { - api_key: config.api_key, - base_url: config.base_url || DEFAULT_BASE_URL, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [config.api_key, config.base_url, onSaved, t, isSaving]) - - return ( - - -
-
-
-
-
{t(`${I18N_PREFIX}.configFirecrawl`, { ns: 'datasetCreation' })}
-
- -
- - -
-
- - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
- - -
- -
-
-
-
- - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
-
-
-
-
-
- ) -} -export default React.memo(ConfigFirecrawlModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx deleted file mode 100644 index 2374ae6174..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx +++ /dev/null @@ -1,144 +0,0 @@ -'use client' -import type { FC } from 'react' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { DataSourceProvider } from '@/models/common' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'jinaReader' - -const ConfigJinaReaderModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [apiKey, setApiKey] = useState('') - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (!errorMsg) { - if (!apiKey) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: DataSourceProvider.jinaReader, - credentials: { - auth_type: 'bearer', - config: { - api_key: apiKey, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [apiKey, onSaved, t, isSaving]) - - return ( - - -
-
-
-
-
{t(`${I18N_PREFIX}.configJinaReader`, { ns: 'datasetCreation' })}
-
- -
- setApiKey(value as string)} - placeholder={t(`${I18N_PREFIX}.apiKeyPlaceholder`, { ns: 'datasetCreation' })!} - /> -
-
- - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
- - -
- -
-
-
-
- - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
-
-
-
-
-
- ) -} -export default React.memo(ConfigJinaReaderModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx deleted file mode 100644 index a9399f25cd..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx +++ /dev/null @@ -1,165 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { WatercrawlConfig } from '@/models/common' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' -import Field from '@/app/components/datasets/create/website/base/field' -import { createDataSourceApiKeyBinding } from '@/service/datasets' - -type Props = { - onCancel: () => void - onSaved: () => void -} - -const I18N_PREFIX = 'watercrawl' - -const DEFAULT_BASE_URL = 'https://app.watercrawl.dev' - -const ConfigWatercrawlModal: FC = ({ - onCancel, - onSaved, -}) => { - const { t } = useTranslation() - const [isSaving, setIsSaving] = useState(false) - const [config, setConfig] = useState({ - api_key: '', - base_url: '', - }) - - const handleConfigChange = useCallback((key: string) => { - return (value: string | number) => { - setConfig(prev => ({ ...prev, [key]: value as string })) - } - }, []) - - const handleSave = useCallback(async () => { - if (isSaving) - return - let errorMsg = '' - if (config.base_url && !((config.base_url.startsWith('http://') || config.base_url.startsWith('https://')))) - errorMsg = t('errorMsg.urlError', { ns: 'common' }) - if (!errorMsg) { - if (!config.api_key) { - errorMsg = t('errorMsg.fieldRequired', { - ns: 'common', - field: 'API Key', - }) - } - } - - if (errorMsg) { - Toast.notify({ - type: 'error', - message: errorMsg, - }) - return - } - const postData = { - category: 'website', - provider: 'watercrawl', - credentials: { - auth_type: 'x-api-key', - config: { - api_key: config.api_key, - base_url: config.base_url || DEFAULT_BASE_URL, - }, - }, - } - try { - setIsSaving(true) - await createDataSourceApiKeyBinding(postData) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) - } - finally { - setIsSaving(false) - } - - onSaved() - }, [config.api_key, config.base_url, onSaved, t, isSaving]) - - return ( - - -
-
-
-
-
{t(`${I18N_PREFIX}.configWatercrawl`, { ns: 'datasetCreation' })}
-
- -
- - -
-
- - {t(`${I18N_PREFIX}.getApiKeyLinkText`, { ns: 'datasetCreation' })} - - -
- - -
- -
-
-
-
- - {t('modelProvider.encrypted.front', { ns: 'common' })} - - PKCS1_OAEP - - {t('modelProvider.encrypted.back', { ns: 'common' })} -
-
-
-
-
-
- ) -} -export default React.memo(ConfigWatercrawlModal) diff --git a/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx deleted file mode 100644 index 22bfb4950e..0000000000 --- a/web/app/components/header/account-setting/data-source-page/data-source-website/index.tsx +++ /dev/null @@ -1,137 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { DataSourceItem } from '@/models/common' -import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' -import s from '@/app/components/datasets/create/website/index.module.css' -import { useAppContext } from '@/context/app-context' -import { DataSourceProvider } from '@/models/common' -import { fetchDataSources, removeDataSourceApiKeyBinding } from '@/service/datasets' -import { cn } from '@/utils/classnames' -import Panel from '../panel' - -import { DataSourceType } from '../panel/types' -import ConfigFirecrawlModal from './config-firecrawl-modal' -import ConfigJinaReaderModal from './config-jina-reader-modal' -import ConfigWatercrawlModal from './config-watercrawl-modal' - -type Props = { - provider: DataSourceProvider -} - -const DataSourceWebsite: FC = ({ provider }) => { - const { t } = useTranslation() - const { isCurrentWorkspaceManager } = useAppContext() - const [sources, setSources] = useState([]) - const checkSetApiKey = useCallback(async () => { - const res = await fetchDataSources() as any - const list = res.sources - setSources(list) - }, []) - - useEffect(() => { - checkSetApiKey() - }, []) - - const [configTarget, setConfigTarget] = useState(null) - const showConfig = useCallback((provider: DataSourceProvider) => { - setConfigTarget(provider) - }, [setConfigTarget]) - - const hideConfig = useCallback(() => { - setConfigTarget(null) - }, [setConfigTarget]) - - const handleAdded = useCallback(() => { - checkSetApiKey() - hideConfig() - }, [checkSetApiKey, hideConfig]) - - const getIdByProvider = (provider: DataSourceProvider): string | undefined => { - const source = sources.find(item => item.provider === provider) - return source?.id - } - - const getProviderName = (provider: DataSourceProvider): string => { - if (provider === DataSourceProvider.fireCrawl) - return 'Firecrawl' - - if (provider === DataSourceProvider.waterCrawl) - return 'WaterCrawl' - - return 'Jina Reader' - } - - const handleRemove = useCallback((provider: DataSourceProvider) => { - return async () => { - const dataSourceId = getIdByProvider(provider) - if (dataSourceId) { - await removeDataSourceApiKeyBinding(dataSourceId) - setSources(sources.filter(item => item.provider !== provider)) - Toast.notify({ - type: 'success', - message: t('api.remove', { ns: 'common' }), - }) - } - } - }, [sources, t]) - - return ( - <> - item.provider === provider) !== undefined} - onConfigure={() => showConfig(provider)} - readOnly={!isCurrentWorkspaceManager} - configuredList={sources.filter(item => item.provider === provider).map(item => ({ - id: item.id, - logo: ({ className }: { className: string }) => { - if (item.provider === DataSourceProvider.fireCrawl) { - return ( -
- 🔥 -
- ) - } - - if (item.provider === DataSourceProvider.waterCrawl) { - return ( -
- -
- ) - } - return ( -
- -
- ) - }, - name: getProviderName(item.provider), - isActive: true, - }))} - onRemove={handleRemove(provider)} - /> - {configTarget === DataSourceProvider.fireCrawl && ( - - )} - {configTarget === DataSourceProvider.waterCrawl && ( - - )} - {configTarget === DataSourceProvider.jinaReader && ( - - )} - - - ) -} -export default React.memo(DataSourceWebsite) diff --git a/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx b/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx deleted file mode 100644 index 4ad49a8f8f..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/__tests__/config-item.spec.tsx +++ /dev/null @@ -1,213 +0,0 @@ -import type { ConfigItemType } from '../config-item' -import { fireEvent, render, screen } from '@testing-library/react' -import ConfigItem from '../config-item' -import { DataSourceType } from '../types' - -/** - * ConfigItem Component Tests - * Tests rendering of individual configuration items for Notion and Website data sources. - */ - -// Mock Operate component to isolate ConfigItem unit tests. -vi.mock('../../data-source-notion/operate', () => ({ - default: ({ onAuthAgain, payload }: { onAuthAgain: () => void, payload: { id: string, total: number } }) => ( -
- - {JSON.stringify(payload)} -
- ), -})) - -describe('ConfigItem Component', () => { - const mockOnRemove = vi.fn() - const mockOnChangeAuthorizedPage = vi.fn() - const MockLogo = (props: React.SVGProps) => - - const baseNotionPayload: ConfigItemType = { - id: 'notion-1', - logo: MockLogo, - name: 'Notion Workspace', - isActive: true, - notionConfig: { total: 5 }, - } - - const baseWebsitePayload: ConfigItemType = { - id: 'website-1', - logo: MockLogo, - name: 'My Website', - isActive: true, - } - - afterEach(() => { - vi.clearAllMocks() - }) - - describe('Notion Configuration', () => { - it('should render active Notion config item with connected status and operator', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByTestId('mock-logo')).toBeInTheDocument() - expect(screen.getByText('Notion Workspace')).toBeInTheDocument() - const statusText = screen.getByText('common.dataSource.notion.connected') - expect(statusText).toHaveClass('text-util-colors-green-green-600') - expect(screen.getByTestId('operate-payload')).toHaveTextContent(JSON.stringify({ id: 'notion-1', total: 5 })) - }) - - it('should render inactive Notion config item with disconnected status', () => { - // Arrange - const inactivePayload = { ...baseNotionPayload, isActive: false } - - // Act - render( - , - ) - - // Assert - const statusText = screen.getByText('common.dataSource.notion.disconnected') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should handle auth action through the Operate component', () => { - // Arrange - render( - , - ) - - // Act - fireEvent.click(screen.getByTestId('operate-auth-btn')) - - // Assert - expect(mockOnChangeAuthorizedPage).toHaveBeenCalled() - }) - - it('should fallback to 0 total if notionConfig is missing', () => { - // Arrange - const payloadNoConfig = { ...baseNotionPayload, notionConfig: undefined } - - // Act - render( - , - ) - - // Assert - expect(screen.getByTestId('operate-payload')).toHaveTextContent(JSON.stringify({ id: 'notion-1', total: 0 })) - }) - - it('should handle missing notionActions safely without crashing', () => { - // Arrange - render( - , - ) - - // Act & Assert - expect(() => fireEvent.click(screen.getByTestId('operate-auth-btn'))).not.toThrow() - }) - }) - - describe('Website Configuration', () => { - it('should render active Website config item and hide operator', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.website.active')).toBeInTheDocument() - expect(screen.queryByTestId('mock-operate')).not.toBeInTheDocument() - }) - - it('should render inactive Website config item', () => { - // Arrange - const inactivePayload = { ...baseWebsitePayload, isActive: false } - - // Act - render( - , - ) - - // Assert - const statusText = screen.getByText('common.dataSource.website.inactive') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should show remove button and trigger onRemove when clicked (not read-only)', () => { - // Arrange - const { container } = render( - , - ) - - // Note: This selector is brittle but necessary since the delete button lacks - // accessible attributes (data-testid, aria-label). Ideally, the component should - // be updated to include proper accessibility attributes. - const deleteBtn = container.querySelector('div[class*="cursor-pointer"]') as HTMLElement - - // Act - fireEvent.click(deleteBtn) - - // Assert - expect(mockOnRemove).toHaveBeenCalled() - }) - - it('should hide remove button in read-only mode', () => { - // Arrange - const { container } = render( - , - ) - - // Assert - const deleteBtn = container.querySelector('div[class*="cursor-pointer"]') - expect(deleteBtn).not.toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx b/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx deleted file mode 100644 index d83cdb5360..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/__tests__/index.spec.tsx +++ /dev/null @@ -1,226 +0,0 @@ -import type { ConfigItemType } from '../config-item' -import { fireEvent, render, screen } from '@testing-library/react' -import { DataSourceProvider } from '@/models/common' -import Panel from '../index' -import { DataSourceType } from '../types' - -/** - * Panel Component Tests - * Tests layout, conditional rendering, and interactions for data source panels (Notion and Website). - */ - -vi.mock('../../data-source-notion/operate', () => ({ - default: () =>
, -})) - -describe('Panel Component', () => { - const onConfigure = vi.fn() - const onRemove = vi.fn() - const mockConfiguredList: ConfigItemType[] = [ - { id: '1', name: 'Item 1', isActive: true, logo: () => null }, - { id: '2', name: 'Item 2', isActive: false, logo: () => null }, - ] - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Notion Panel Rendering', () => { - it('should render Notion panel when not configured and isSupportList is true', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.notion.title')).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.description')).toBeInTheDocument() - const connectBtn = screen.getByText('common.dataSource.connect') - expect(connectBtn).toBeInTheDocument() - - // Act - fireEvent.click(connectBtn) - // Assert - expect(onConfigure).toHaveBeenCalled() - }) - - it('should render Notion panel in readOnly mode when not configured', () => { - // Act - render( - , - ) - - // Assert - const connectBtn = screen.getByText('common.dataSource.connect') - expect(connectBtn).toHaveClass('cursor-default opacity-50 grayscale') - }) - - it('should render Notion panel when configured with list of items', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByRole('button', { name: 'common.dataSource.configure' })).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.connectedWorkspace')).toBeInTheDocument() - expect(screen.getByText('Item 1')).toBeInTheDocument() - expect(screen.getByText('Item 2')).toBeInTheDocument() - }) - - it('should hide connect button for Notion if isSupportList is false', () => { - // Act - render( - , - ) - - // Assert - expect(screen.queryByText('common.dataSource.connect')).not.toBeInTheDocument() - }) - - it('should disable Notion configure button in readOnly mode (configured state)', () => { - // Act - render( - , - ) - - // Assert - const btn = screen.getByRole('button', { name: 'common.dataSource.configure' }) - expect(btn).toBeDisabled() - }) - }) - - describe('Website Panel Rendering', () => { - it('should show correct provider names and handle configuration when not configured', () => { - // Arrange - const { rerender } = render( - , - ) - - // Assert Firecrawl - expect(screen.getByText('🔥 Firecrawl')).toBeInTheDocument() - - // Rerender for WaterCrawl - rerender( - , - ) - expect(screen.getByText('WaterCrawl')).toBeInTheDocument() - - // Rerender for Jina Reader - rerender( - , - ) - expect(screen.getByText('Jina Reader')).toBeInTheDocument() - - // Act - const configBtn = screen.getByText('common.dataSource.configure') - fireEvent.click(configBtn) - // Assert - expect(onConfigure).toHaveBeenCalled() - }) - - it('should handle readOnly mode for Website configuration button', () => { - // Act - render( - , - ) - - // Assert - const configBtn = screen.getByText('common.dataSource.configure') - expect(configBtn).toHaveClass('cursor-default opacity-50 grayscale') - - // Act - fireEvent.click(configBtn) - // Assert - expect(onConfigure).not.toHaveBeenCalled() - }) - - it('should render Website panel correctly when configured with crawlers', () => { - // Act - render( - , - ) - - // Assert - expect(screen.getByText('common.dataSource.website.configuredCrawlers')).toBeInTheDocument() - expect(screen.getByText('Item 1')).toBeInTheDocument() - expect(screen.getByText('Item 2')).toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx deleted file mode 100644 index f62c5e147d..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx +++ /dev/null @@ -1,85 +0,0 @@ -'use client' -import type { FC } from 'react' -import { - RiDeleteBinLine, -} from '@remixicon/react' -import { noop } from 'es-toolkit/function' -import * as React from 'react' -import { useTranslation } from 'react-i18next' -import { cn } from '@/utils/classnames' -import Indicator from '../../../indicator' -import Operate from '../data-source-notion/operate' -import s from './style.module.css' -import { DataSourceType } from './types' - -export type ConfigItemType = { - id: string - logo: any - name: string - isActive: boolean - notionConfig?: { - total: number - } -} - -type Props = { - type: DataSourceType - payload: ConfigItemType - onRemove: () => void - notionActions?: { - onChangeAuthorizedPage: () => void - } - readOnly: boolean -} - -const ConfigItem: FC = ({ - type, - payload, - onRemove, - notionActions, - readOnly, -}) => { - const { t } = useTranslation() - const isNotion = type === DataSourceType.notion - const isWebsite = type === DataSourceType.website - const onChangeAuthorizedPage = notionActions?.onChangeAuthorizedPage || noop - - return ( -
- -
{payload.name}
- { - payload.isActive - ? - : - } -
- { - payload.isActive - ? t(isNotion ? 'dataSource.notion.connected' : 'dataSource.website.active', { ns: 'common' }) - : t(isNotion ? 'dataSource.notion.disconnected' : 'dataSource.website.inactive', { ns: 'common' }) - } -
-
- {isNotion && ( - - )} - - { - isWebsite && !readOnly && ( -
- -
- ) - } - -
- ) -} -export default React.memo(ConfigItem) diff --git a/web/app/components/header/account-setting/data-source-page/panel/index.tsx b/web/app/components/header/account-setting/data-source-page/panel/index.tsx deleted file mode 100644 index 0909603ae8..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/index.tsx +++ /dev/null @@ -1,151 +0,0 @@ -'use client' -import type { FC } from 'react' -import type { ConfigItemType } from './config-item' -import { RiAddLine } from '@remixicon/react' -import * as React from 'react' -import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' - -import { DataSourceProvider } from '@/models/common' -import { cn } from '@/utils/classnames' -import ConfigItem from './config-item' -import s from './style.module.css' -import { DataSourceType } from './types' - -type Props = { - type: DataSourceType - provider?: DataSourceProvider - isConfigured: boolean - onConfigure: () => void - readOnly: boolean - isSupportList?: boolean - configuredList: ConfigItemType[] - onRemove: () => void - notionActions?: { - onChangeAuthorizedPage: () => void - } -} - -const Panel: FC = ({ - type, - provider, - isConfigured, - onConfigure, - readOnly, - configuredList, - isSupportList, - onRemove, - notionActions, -}) => { - const { t } = useTranslation() - const isNotion = type === DataSourceType.notion - const isWebsite = type === DataSourceType.website - - const getProviderName = (): string => { - if (provider === DataSourceProvider.fireCrawl) - return '🔥 Firecrawl' - if (provider === DataSourceProvider.waterCrawl) - return 'WaterCrawl' - return 'Jina Reader' - } - - return ( -
-
-
-
-
-
{t(`dataSource.${type}.title`, { ns: 'common' })}
- {isWebsite && ( -
- {t('dataSource.website.with', { ns: 'common' })} - {' '} - {getProviderName()} -
- )} -
- { - !isConfigured && ( -
- {t(`dataSource.${type}.description`, { ns: 'common' })} -
- ) - } -
- {isNotion && ( - <> - { - isConfigured - ? ( - - ) - : ( - <> - {isSupportList && ( -
- - {t('dataSource.connect', { ns: 'common' })} -
- )} - - ) - } - - )} - - {isWebsite && !isConfigured && ( -
- {t('dataSource.configure', { ns: 'common' })} -
- )} - -
- { - isConfigured && ( - <> -
-
- {isNotion ? t('dataSource.notion.connectedWorkspace', { ns: 'common' }) : t('dataSource.website.configuredCrawlers', { ns: 'common' })} -
-
-
-
- { - configuredList.map(item => ( - - )) - } -
- - ) - } -
- ) -} -export default React.memo(Panel) diff --git a/web/app/components/header/account-setting/data-source-page/panel/style.module.css b/web/app/components/header/account-setting/data-source-page/panel/style.module.css deleted file mode 100644 index ac9be02205..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/style.module.css +++ /dev/null @@ -1,17 +0,0 @@ -.notion-icon { - background: #ffffff url(../../../assets/notion.svg) center center no-repeat; - background-size: 20px 20px; -} - -.website-icon { - background: #ffffff url(../../../../datasets/create/assets/web.svg) center center no-repeat; - background-size: 20px 20px; -} - -.workspace-item { - box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); -} - -.workspace-item:last-of-type { - margin-bottom: 0; -} diff --git a/web/app/components/header/account-setting/data-source-page/panel/types.ts b/web/app/components/header/account-setting/data-source-page/panel/types.ts deleted file mode 100644 index 345bc10f81..0000000000 --- a/web/app/components/header/account-setting/data-source-page/panel/types.ts +++ /dev/null @@ -1,4 +0,0 @@ -export enum DataSourceType { - notion = 'notion', - website = 'website', -} diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 7e77af2e5f..bfceaeb059 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -1,8 +1,9 @@ 'use client' import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' -import { useCallback, useEffect, useRef, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import SearchInput from '@/app/components/base/search-input' +import { ScrollArea } from '@/app/components/base/ui/scroll-area' import BillingPage from '@/app/components/billing/billing-page' import CustomPage from '@/app/components/custom/custom-page' import { @@ -129,20 +130,6 @@ export default function AccountSetting({ ], }, ] - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - useEffect(() => { - const targetElement = scrollRef.current - const scrollHandle = (e: Event) => { - const userScrolled = (e.target as HTMLDivElement).scrollTop > 0 - setScrolled(userScrolled) - } - targetElement?.addEventListener('scroll', scrollHandle) - return () => { - targetElement?.removeEventListener('scroll', scrollHandle) - } - }, []) - const activeItem = [...menuItems[0].items, ...menuItems[1].items].find(item => item.key === activeMenu) const [searchValue, setSearchValue] = useState('') @@ -201,7 +188,7 @@ export default function AccountSetting({ }
-
+
-
-
+ +
{activeItem?.name} {activeItem?.description && ( @@ -241,7 +234,7 @@ export default function AccountSetting({ {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && } {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && }
-
+
diff --git a/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx b/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx index d2aeca1b6c..7de1fbeccb 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/__tests__/index.spec.tsx @@ -2,11 +2,15 @@ import type { InvitationResponse } from '@/models/common' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import InviteModal from '../index' +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + vi.mock('@/context/provider-context', () => ({ useProviderContextSelector: vi.fn(), useProviderContext: vi.fn(() => ({ @@ -14,6 +18,11 @@ vi.mock('@/context/provider-context', () => ({ })), })) vi.mock('@/service/common') +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + }, +})) vi.mock('@/context/i18n', () => ({ useLocale: () => 'en-US', })) @@ -37,7 +46,6 @@ describe('InviteModal', () => { const mockOnCancel = vi.fn() const mockOnSend = vi.fn() const mockRefreshLicenseLimit = vi.fn() - const mockNotify = vi.fn() beforeEach(() => { vi.clearAllMocks() @@ -49,10 +57,11 @@ describe('InviteModal', () => { }) const renderModal = (isEmailSetup = true) => render( - - - , + , ) + const fillEmails = (value: string) => { + fireEvent.change(screen.getByTestId('mock-email-input'), { target: { value } }) + } it('should render invite modal content', async () => { renderModal() @@ -68,12 +77,8 @@ describe('InviteModal', () => { }) it('should enable send button after entering an email', async () => { - const user = userEvent.setup() - renderModal() - - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByRole('button', { name: /members\.sendInvite/i })).toBeEnabled() }) @@ -84,7 +89,7 @@ describe('InviteModal', () => { renderModal() - await user.type(screen.getByTestId('mock-email-input'), 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -103,8 +108,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -116,8 +120,6 @@ describe('InviteModal', () => { }) it('should keep send button disabled when license limit is exceeded', async () => { - const user = userEvent.setup() - vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, refreshLicenseLimit: mockRefreshLicenseLimit, @@ -125,8 +127,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByRole('button', { name: /members\.sendInvite/i })).toBeDisabled() }) @@ -144,15 +145,11 @@ describe('InviteModal', () => { const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') // Use an email that passes basic validation but fails our strict regex (needs 2+ char TLD) - await user.type(input, 'invalid@email.c') + fillEmails('invalid@email.c') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) - expect(mockNotify).toHaveBeenCalledWith({ - type: 'error', - message: 'common.members.emailInvalid', - }) + expect(toast.error).toHaveBeenCalledWith('common.members.emailInvalid') expect(inviteMember).not.toHaveBeenCalled() }) @@ -160,8 +157,7 @@ describe('InviteModal', () => { const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') expect(screen.getByText('user@example.com')).toBeInTheDocument() @@ -203,7 +199,7 @@ describe('InviteModal', () => { renderModal() - await user.type(screen.getByTestId('mock-email-input'), 'user@example.com') + fillEmails('user@example.com') await user.click(screen.getByRole('button', { name: /members\.sendInvite/i })) await waitFor(() => { @@ -214,8 +210,6 @@ describe('InviteModal', () => { }) it('should show destructive text color when used size exceeds limit', async () => { - const user = userEvent.setup() - vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, refreshLicenseLimit: mockRefreshLicenseLimit, @@ -223,8 +217,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // usedSize = 10 + 1 = 11 > limit 10 → destructive color const counter = screen.getByText('11') @@ -241,8 +234,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -264,8 +256,6 @@ describe('InviteModal', () => { }) it('should show destructive color and disable send button when limit is exactly met with one email', async () => { - const user = userEvent.setup() - // size=10, limit=10 - adding 1 email makes usedSize=11 > limit=10 vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({ licenseLimit: { workspace_members: { size: 10, limit: 10 } }, @@ -274,8 +264,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // isLimitExceeded=true → button is disabled, cannot submit const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -293,8 +282,7 @@ describe('InviteModal', () => { renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i }) @@ -320,11 +308,9 @@ describe('InviteModal', () => { refreshLicenseLimit: mockRefreshLicenseLimit, } as unknown as Parameters[0])) - const user = userEvent.setup() renderModal() - const input = screen.getByTestId('mock-email-input') - await user.type(input, 'user@example.com') + fillEmails('user@example.com') // isLimited=false → no destructive color const counter = screen.getByText('1') diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.module.css b/web/app/components/header/account-setting/members-page/invite-modal/index.module.css deleted file mode 100644 index fbaa1187bd..0000000000 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.module.css +++ /dev/null @@ -1,12 +0,0 @@ -.modal { - padding: 24px 32px !important; - width: 400px !important; -} - -.emailsInput { - background-color: rgb(243 244 246 / var(--tw-bg-opacity)) !important; -} - -.emailBackground { - background-color: white !important; -} diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index 8e4e47e0b8..9b4e9fccdc 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -2,20 +2,17 @@ import type { RoleKey } from './role-selector' import type { InvitationResult } from '@/models/common' import { useBoolean } from 'ahooks' -import { noop } from 'es-toolkit/function' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { ReactMultiEmail } from 'react-multi-email' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' -import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@/app/components/base/ui/dialog' +import { toast } from '@/app/components/base/ui/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import { cn } from '@/utils/classnames' -import s from './index.module.css' import RoleSelector from './role-selector' import 'react-multi-email/dist/style.css' @@ -34,7 +31,6 @@ const InviteModal = ({ const licenseLimit = useProviderContextSelector(s => s.licenseLimit) const refreshLicenseLimit = useProviderContextSelector(s => s.refreshLicenseLimit) const [emails, setEmails] = useState([]) - const { notify } = useContext(ToastContext) const [isLimited, setIsLimited] = useState(false) const [isLimitExceeded, setIsLimitExceeded] = useState(false) const [usedSize, setUsedSize] = useState(licenseLimit.workspace_members.size ?? 0) @@ -74,21 +70,28 @@ const InviteModal = ({ catch { } } else { - notify({ type: 'error', message: t('members.emailInvalid', { ns: 'common' }) }) + toast.error(t('members.emailInvalid', { ns: 'common' })) } setIsSubmitted() - }, [isLimitExceeded, emails, role, locale, onCancel, onSend, notify, t, isSubmitting, refreshLicenseLimit, setIsSubmitted, setIsSubmitting]) + }, [isLimitExceeded, emails, role, locale, onCancel, onSend, t, isSubmitting, refreshLicenseLimit, setIsSubmitted, setIsSubmitting]) return ( -
- -
-
{t('members.inviteTeamMember', { ns: 'common' })}
-
+ { + if (!open) + onCancel() + }} + > + + +
+ + {t('members.inviteTeamMember', { ns: 'common' })} +
{t('members.inviteTeamMemberTip', { ns: 'common' })}
{!isEmailSetup && ( @@ -152,8 +155,8 @@ const InviteModal = ({ {t('members.sendInvite', { ns: 'common' })}
- -
+ + ) } diff --git a/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx b/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx index e258884b0f..6383b203d9 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/role-selector.tsx @@ -1,11 +1,10 @@ import * as React from 'react' -import { useState } from 'react' import { useTranslation } from 'react-i18next' import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -25,115 +24,111 @@ export type RoleSelectorProps = { const RoleSelector = ({ value, onChange }: RoleSelectorProps) => { const { t } = useTranslation() - const [open, setOpen] = useState(false) const { datasetOperatorEnabled } = useProviderContext() + const [open, setOpen] = React.useState(false) return ( - -
- setOpen(v => !v)} - className="block" - > + +
{t('members.invitedAsRole', { ns: 'common', role: t(roleI18nKeyMap[value], { ns: 'common' }) })}
+
+ + +
{ + onChange('normal') + setOpen(false) + }} > -
{t('members.invitedAsRole', { ns: 'common', role: t(roleI18nKeyMap[value], { ns: 'common' }) })}
-
-
- - -
-
-
{ - onChange('normal') - setOpen(false) - }} - > -
-
{t('members.normal', { ns: 'common' })}
-
{t('members.normalTip', { ns: 'common' })}
- {value === 'normal' && ( -
- )} -
-
-
{ - onChange('editor') - setOpen(false) - }} - > -
-
{t('members.editor', { ns: 'common' })}
-
{t('members.editorTip', { ns: 'common' })}
- {value === 'editor' && ( -
- )} -
-
-
{ - onChange('admin') - setOpen(false) - }} - > -
-
{t('members.admin', { ns: 'common' })}
-
{t('members.adminTip', { ns: 'common' })}
- {value === 'admin' && ( -
- )} -
-
- {datasetOperatorEnabled && ( +
+
{t('members.normal', { ns: 'common' })}
+
{t('members.normalTip', { ns: 'common' })}
+ {value === 'normal' && (
{ - onChange('dataset_operator') - setOpen(false) - }} - > -
-
{t('members.datasetOperator', { ns: 'common' })}
-
{t('members.datasetOperatorTip', { ns: 'common' })}
- {value === 'dataset_operator' && ( -
- )} -
-
+ data-testid="role-option-check" + className="i-custom-vender-line-general-check absolute left-0 top-0.5 h-4 w-4 text-text-accent" + /> )}
- -
- +
{ + onChange('editor') + setOpen(false) + }} + > +
+
{t('members.editor', { ns: 'common' })}
+
{t('members.editorTip', { ns: 'common' })}
+ {value === 'editor' && ( +
+ )} +
+
+
{ + onChange('admin') + setOpen(false) + }} + > +
+
{t('members.admin', { ns: 'common' })}
+
{t('members.adminTip', { ns: 'common' })}
+ {value === 'admin' && ( +
+ )} +
+
+ {datasetOperatorEnabled && ( +
{ + onChange('dataset_operator') + setOpen(false) + }} + > +
+
{t('members.datasetOperator', { ns: 'common' })}
+
{t('members.datasetOperatorTip', { ns: 'common' })}
+ {value === 'dataset_operator' && ( +
+ )} +
+
+ )} +
+ + ) } diff --git a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx index 389db4a42d..dbabb384a2 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx @@ -1,15 +1,10 @@ import type { InvitationResult } from '@/models/common' -import { XMarkIcon } from '@heroicons/react/24/outline' -import { CheckCircleIcon } from '@heroicons/react/24/solid' -import { RiQuestionLine } from '@remixicon/react' -import { noop } from 'es-toolkit/function' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import Modal from '@/app/components/base/modal' -import Tooltip from '@/app/components/base/tooltip' +import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@/app/components/base/ui/dialog' +import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { IS_CE_EDITION } from '@/config' -import s from './index.module.css' import InvitationLink from './invitation-link' export type SuccessInvitationResult = Extract @@ -29,8 +24,18 @@ const InvitedModal = ({ const failedInvitationResults = useMemo(() => invitationResults?.filter(item => item.status !== 'success') as FailedInvitationResult[], [invitationResults]) return ( -
- + { + if (!open) + onCancel() + }} + > + +
- +
-
-
{t('members.invitationSent', { ns: 'common' })}
+ {t('members.invitationSent', { ns: 'common' })} {!IS_CE_EDITION && (
{t('members.invitationSentTip', { ns: 'common' })}
)} @@ -54,7 +58,7 @@ const InvitedModal = ({ !!successInvitationResults.length && ( <> -
{t('members.invitationLink', { ns: 'common' })}
+
{t('members.invitationLink', { ns: 'common' })}
{successInvitationResults.map(item => )} @@ -64,18 +68,23 @@ const InvitedModal = ({ !!failedInvitationResults.length && ( <> -
{t('members.failedInvitationEmails', { ns: 'common' })}
+
{t('members.failedInvitationEmails', { ns: 'common' })}
{ failedInvitationResults.map(item => (
- -
- {item.email} - -
+ + + {item.email} +
+
+ )} + /> + + {item.message} +
), @@ -97,8 +106,8 @@ const InvitedModal = ({ {t('members.ok', { ns: 'common' })}
- -
+
+
) } diff --git a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx index 8f55660fd8..0c5874c4dc 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx @@ -4,7 +4,7 @@ import copy from 'copy-to-clipboard' import { t } from 'i18next' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' -import Tooltip from '@/app/components/base/tooltip' +import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import s from './index.module.css' type IInvitationLinkProps = { @@ -38,20 +38,28 @@ const InvitationLink = ({
- -
{value.url}
+ + {value.url}
} + /> + + {isCopied ? t('copied', { ns: 'appApi' }) : t('copy', { ns: 'appApi' })} +
- -
-
-
-
+ + +
+
+
+ )} + /> + + {isCopied ? t('copied', { ns: 'appApi' }) : t('copy', { ns: 'appApi' })} +
diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx index 35c4676d5f..e2b14b9078 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.tsx @@ -102,7 +102,7 @@ const Operation = ({
- +
{ diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx index 099a146866..6a2af9ffdb 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx @@ -141,6 +141,7 @@ const TransferOwnershipModal = ({ onClose, show }: Props) => {
= ({
- +
({ ), })) -// Mock portal components to avoid async/jsdom issues (consistent with sibling tests) +// Mock portal components to avoid async test DOM issues (consistent with sibling tests) vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean, onOpenChange: (open: boolean) => void }) => (
diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx index 4025e307f1..536de9bbdf 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx @@ -116,7 +116,7 @@ const AddCustomModel = ({ > {renderTrigger(open)} - +
{ @@ -136,7 +136,7 @@ const AddCustomModel = ({ modelName={model.model} />
{model.model} @@ -148,7 +148,7 @@ const AddCustomModel = ({ { !notAllowCustomCredential && (
{ handleOpenModalForAddNewCustomModel() setOpen(false) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx index e2f859b09d..15101a6542 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx @@ -164,7 +164,7 @@ const Authorized = ({ > {renderTrigger(mergedIsOpen)} - +
{ popupTitle && ( -
+
{popupTitle}
) @@ -218,7 +218,7 @@ const Authorized = ({ } : undefined, )} - className="system-xs-medium flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only" + className="flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only system-xs-medium" > {t('modelProvider.auth.addModelCredential', { ns: 'common' })} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx index 52513e7aeb..dd1d8e6eb9 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx @@ -53,14 +53,14 @@ const CredentialSelector = ({ triggerPopupSameWidth > !disabled && setOpen(v => !v)}> -
+
{ selectedCredential && (
{ !selectedCredential.addNewCredential && } -
{selectedCredential.credential_name}
+
{selectedCredential.credential_name}
{ selectedCredential.from_enterprise && ( Enterprise @@ -71,13 +71,13 @@ const CredentialSelector = ({ } { !selectedCredential && ( -
{t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
+
{t('modelProvider.auth.selectModelCredential', { ns: 'common' })}
) }
- +
{ @@ -98,7 +98,7 @@ const CredentialSelector = ({ { !notAllowAddNewCredential && (
diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx index 496058bf9b..5c8d5e7489 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/__tests__/index.spec.tsx @@ -1,7 +1,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import ModelParameterModal from '../index' -let isAPIKeySet = true let parameterRules: Array> | undefined = [ { name: 'temperature', @@ -40,7 +39,7 @@ let activeTextGenerationModelList: Array> = [ vi.mock('@/context/provider-context', () => ({ useProviderContext: () => ({ - isAPIKeySet, + isAPIKeySet: true, }), })) @@ -50,6 +49,7 @@ vi.mock('@/service/use-common', () => ({ data: parameterRules, }, isLoading: isRulesLoading, + isPending: isRulesLoading, }), })) @@ -62,12 +62,18 @@ vi.mock('../../hooks', () => ({ })) vi.mock('../parameter-item', () => ({ - default: ({ parameterRule, onChange, onSwitch }: { + default: ({ parameterRule, onChange, onSwitch, nodesOutputVars, availableNodes }: { parameterRule: { name: string, label: { en_US: string } } onChange: (v: number) => void onSwitch: (checked: boolean, val: unknown) => void + nodesOutputVars?: unknown[] + availableNodes?: unknown[] }) => ( -
+
{parameterRule.label.en_US} @@ -119,7 +125,6 @@ describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - isAPIKeySet = true isRulesLoading = false parameterRules = [ { @@ -233,6 +238,26 @@ describe('ModelParameterModal', () => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) + it('should pass nodesOutputVars and availableNodes to ParameterItem', () => { + const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }] + const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }] + + render( + , + ) + + fireEvent.click(screen.getByText('Open Settings')) + + const paramEl = screen.getByTestId('param-temperature') + expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true') + expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true') + }) + it('should support custom triggers, workflow mode, and missing default model values', async () => { render( ({ useLanguage: () => 'en_US', })) -vi.mock('@/app/components/base/slider', () => ({ - default: ({ onChange }: { onChange: (v: number) => void }) => ( - +vi.mock('@/app/components/base/ui/slider', () => ({ + Slider: ({ onValueChange }: { onValueChange: (v: number) => void }) => ( + ), })) @@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({ ), })) +let promptEditorOnChange: ((text: string) => void) | undefined +let capturedWorkflowNodesMap: Record | undefined + +vi.mock('@/app/components/base/prompt-editor', () => ({ + default: ({ value, onChange, workflowVariableBlock }: { + value: string + onChange: (text: string) => void + workflowVariableBlock?: { + show: boolean + variables: NodeOutPutVar[] + workflowNodesMap?: Record + } + }) => { + promptEditorOnChange = onChange + capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap + return ( +
+ {value} +
+ ) + }, +})) + describe('ParameterItem', () => { const createRule = (overrides: Partial = {}): ModelParameterRule => ({ name: 'temp', @@ -30,9 +58,10 @@ describe('ParameterItem', () => { beforeEach(() => { vi.clearAllMocks() + promptEditorOnChange = undefined + capturedWorkflowNodesMap = undefined }) - // Float tests it('should render float controls and clamp numeric input to max', () => { const onChange = vi.fn() render() @@ -50,7 +79,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(0.1) }) - // Int tests it('should render int controls and clamp numeric input', () => { const onChange = vi.fn() render() @@ -75,22 +103,17 @@ describe('ParameterItem', () => { it('should render int input without slider if min or max is missing', () => { render() expect(screen.queryByRole('slider')).not.toBeInTheDocument() - // No max -> precision step expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0') }) - // Slider events (uses generic value mock for slider) it('should handle slide change and clamp values', () => { const onChange = vi.fn() render() - // Test that the actual slider triggers the onChange logic correctly - // The implementation of Slider uses onChange(val) directly via the mock fireEvent.click(screen.getByTestId('slider-btn')) expect(onChange).toHaveBeenCalledWith(2) }) - // Text & String tests it('should render exact string input and propagate text changes', () => { const onChange = vi.fn() render() @@ -109,21 +132,17 @@ describe('ParameterItem', () => { it('should render select for string with options', () => { render() - // Select renders the selected value in the trigger expect(screen.getByText('a')).toBeInTheDocument() }) - // Tag Tests it('should render tag input for tag type', () => { const onChange = vi.fn() render() expect(screen.getByText('placeholder')).toBeInTheDocument() - // Trigger mock tag input fireEvent.click(screen.getByTestId('tag-input')) expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2']) }) - // Boolean tests it('should render boolean radios and update value on click', () => { const onChange = vi.fn() render() @@ -131,7 +150,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(false) }) - // Switch tests it('should call onSwitch with current value when optional switch is toggled off', () => { const onSwitch = vi.fn() render() @@ -146,7 +164,6 @@ describe('ParameterItem', () => { expect(screen.queryByRole('switch')).not.toBeInTheDocument() }) - // Default Value Fallbacks (rendering without value) it('should use default values if value is undefined', () => { const { rerender } = render() expect(screen.getByRole('spinbutton')).toHaveValue(0.5) @@ -158,26 +175,102 @@ describe('ParameterItem', () => { expect(screen.getByText('True')).toBeInTheDocument() expect(screen.getByText('False')).toBeInTheDocument() - // Without default - rerender() // min is 0 by default in createRule + rerender() expect(screen.getByRole('spinbutton')).toHaveValue(0) }) - // Input Blur it('should reset input to actual bound value on blur', () => { render() const input = screen.getByRole('spinbutton') - // change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state) - // Actually our test fires a change so localValue = 1, then blur sets it fireEvent.change(input, { target: { value: '5' } }) fireEvent.blur(input) expect(input).toHaveValue(1) }) - // Unsupported it('should render no input for unsupported parameter type', () => { render() expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument() }) + + describe('workflow variable reference', () => { + const mockNodesOutputVars: NodeOutPutVar[] = [ + { nodeId: 'node1', title: 'LLM Node', vars: [] }, + ] + const mockAvailableNodes: Node[] = [ + { id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node, + { id: 'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node, + ] + + it('should build workflowNodesMap and render PromptEditor for string type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node') + expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start') + expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start) + + promptEditorOnChange?.('updated text') + expect(onChange).toHaveBeenCalledWith('updated text') + }) + + it('should build workflowNodesMap and render PromptEditor for text type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + + promptEditorOnChange?.('new long text') + expect(onChange).toHaveBeenCalledWith('new long text') + }) + + it('should fall back to plain input when not in workflow mode for string type', () => { + render( + , + ) + + expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('should return undefined workflowNodesMap when not in workflow mode', () => { + render( + , + ) + + expect(capturedWorkflowNodesMap).toBeUndefined() + }) + }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index 6b4018e2aa..ccb2c67a0d 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -9,6 +9,10 @@ import type { } from '../declarations' import type { ParameterValue } from './parameter-item' import type { TriggerProps } from './trigger' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' @@ -45,6 +49,8 @@ export type ModelParameterModalProps = { readonly?: boolean isInWorkflow?: boolean scope?: string + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } const ModelParameterModal: FC = ({ @@ -61,11 +67,18 @@ const ModelParameterModal: FC = ({ renderTrigger, readonly, isInWorkflow, + nodesOutputVars, + availableNodes, }) => { const { t } = useTranslation() const [open, setOpen] = useState(false) const settingsIconRef = useRef(null) - const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId) + const { + data: parameterRulesData, + isPending, + isLoading, + } = useModelParameterRules(provider, modelId) + const isRulesLoading = isPending || isLoading const { currentProvider, currentModel, @@ -191,7 +204,7 @@ const ModelParameterModal: FC = ({ }
{ - isLoading + isRulesLoading ?
: ( [ @@ -205,6 +218,8 @@ const ModelParameterModal: FC = ({ onChange={v => handleParamChange(parameter.name, v)} onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)} isInWorkflow={isInWorkflow} + nodesOutputVars={nodesOutputVars} + availableNodes={availableNodes} /> )) ) @@ -213,7 +228,7 @@ const ModelParameterModal: FC = ({ ) } { - !parameterRules.length && isLoading && ( + !parameterRules.length && isRulesLoading && (
) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 86fb6d81d0..162e39d162 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,11 +1,18 @@ import type { ModelParameterRule } from '../declarations' -import { useEffect, useRef, useState } from 'react' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' +import { useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import PromptEditor from '@/app/components/base/prompt-editor' import Radio from '@/app/components/base/radio' -import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import TagInput from '@/app/components/base/tag-input' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' +import { Slider } from '@/app/components/base/ui/slider' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' +import { BlockEnum } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -18,18 +25,43 @@ type ParameterItemProps = { onChange?: (value: ParameterValue) => void onSwitch?: (checked: boolean, assignValue: ParameterValue) => void isInWorkflow?: boolean + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } + function ParameterItem({ parameterRule, value, onChange, onSwitch, isInWorkflow, + nodesOutputVars, + availableNodes = [], }: ParameterItemProps) { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) + const workflowNodesMap = useMemo(() => { + if (!isInWorkflow || !availableNodes.length) + return undefined + + return availableNodes.reduce>>((acc, node) => { + acc[node.id] = { + title: node.data.title, + type: node.data.type, + } + if (node.data.type === BlockEnum.Start) { + acc.sys = { + title: t('blocks.start', { ns: 'workflow' }), + type: BlockEnum.Start, + } + } + return acc + }, {}) + }, [availableNodes, isInWorkflow, t]) + const getDefaultValue = () => { let defaultValue: ParameterValue @@ -46,6 +78,7 @@ function ParameterItem({ } const renderValue = value ?? localValue ?? getDefaultValue() + const sliderLabel = parameterRule.label[language] || parameterRule.label.en_US const handleInputChange = (newValue: ParameterValue) => { setLocalValue(newValue) @@ -138,7 +171,8 @@ function ParameterItem({ min={parameterRule.min} max={parameterRule.max} step={step} - onChange={handleSlideChange} + onValueChange={handleSlideChange} + aria-label={sliderLabel} /> )} )} + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return ( + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return (